diff --git a/src/libraries/System.Numerics.Tensors/ref/System.Numerics.Tensors.netcore.cs b/src/libraries/System.Numerics.Tensors/ref/System.Numerics.Tensors.netcore.cs
index 1cde4351546b26..da348aca973905 100644
--- a/src/libraries/System.Numerics.Tensors/ref/System.Numerics.Tensors.netcore.cs
+++ b/src/libraries/System.Numerics.Tensors/ref/System.Numerics.Tensors.netcore.cs
@@ -8,7 +8,49 @@ namespace System.Numerics.Tensors
 {
     public static partial class TensorPrimitives
     {
+        public static void Abs<T>(System.ReadOnlySpan<T> x, System.Span<T> destination) where T : System.Numerics.INumberBase<T> { }
+        public static void AddMultiply<T>(System.ReadOnlySpan<T> x, System.ReadOnlySpan<T> y, System.ReadOnlySpan<T> multiplier, System.Span<T> destination) where T : System.Numerics.IAdditionOperators<T, T, T>, System.Numerics.IMultiplyOperators<T, T, T> { }
+        public static void AddMultiply<T>(System.ReadOnlySpan<T> x, System.ReadOnlySpan<T> y, T multiplier, System.Span<T> destination) where T : System.Numerics.IAdditionOperators<T, T, T>, System.Numerics.IMultiplyOperators<T, T, T> { }
+        public static void AddMultiply<T>(System.ReadOnlySpan<T> x, T y, System.ReadOnlySpan<T> multiplier, System.Span<T> destination) where T : System.Numerics.IAdditionOperators<T, T, T>, System.Numerics.IMultiplyOperators<T, T, T> { }
+        public static void Add<T>(System.ReadOnlySpan<T> x, System.ReadOnlySpan<T> y, System.Span<T> destination) where T : System.Numerics.IAdditionOperators<T, T, T>, System.Numerics.IAdditiveIdentity<T, T> { }
+        public static void Add<T>(System.ReadOnlySpan<T> x, T y, System.Span<T> destination) where T : System.Numerics.IAdditionOperators<T, T, T>, System.Numerics.IAdditiveIdentity<T, T> { }
         public static void ConvertToHalf(System.ReadOnlySpan<float> source, System.Span<System.Half> destination) { throw null; }
         public static void ConvertToSingle(System.ReadOnlySpan<System.Half> source, System.Span<float> destination) { throw null; }
+        public static void Cosh<T>(System.ReadOnlySpan<T> x, System.Span<T> destination) where T : System.Numerics.IHyperbolicFunctions<T> { }
+        public static T CosineSimilarity<T>(System.ReadOnlySpan<T> x, System.ReadOnlySpan<T> y) where T : System.Numerics.IRootFunctions<T> { throw null; }
+        public static T Distance<T>(System.ReadOnlySpan<T> x, System.ReadOnlySpan<T> y) where T : System.Numerics.IRootFunctions<T> { throw null; }
+        public static void Divide<T>(System.ReadOnlySpan<T> x, System.ReadOnlySpan<T> y, System.Span<T> destination) where T : System.Numerics.IDivisionOperators<T, T, T> { }
+        public static void Divide<T>(System.ReadOnlySpan<T> x, T y, System.Span<T> destination) where T : System.Numerics.IDivisionOperators<T, T, T> { }
+        public static T Dot<T>(System.ReadOnlySpan<T> x, System.ReadOnlySpan<T> y) where T : System.Numerics.IAdditionOperators<T, T, T>, System.Numerics.IAdditiveIdentity<T, T>, System.Numerics.IMultiplyOperators<T, T, T>, System.Numerics.IMultiplicativeIdentity<T, T> { throw null; }
+        public static void Exp<T>(System.ReadOnlySpan<T> x, System.Span<T> destination) where T : System.Numerics.IExponentialFunctions<T> { }
+        public static void Log2<T>(System.ReadOnlySpan<T> x, System.Span<T> destination) where T : System.Numerics.ILogarithmicFunctions<T> { }
+        public static void Log<T>(System.ReadOnlySpan<T> x, System.Span<T> destination) where T : System.Numerics.ILogarithmicFunctions<T> { }
+        public static T MaxMagnitude<T>(System.ReadOnlySpan<T> x) where T : System.Numerics.INumberBase<T> { throw null; }
+        public static void MaxMagnitude<T>(System.ReadOnlySpan<T> x, System.ReadOnlySpan<T> y, System.Span<T> destination) where T : System.Numerics.INumberBase<T> { }
+        public static T Max<T>(System.ReadOnlySpan<T> x) where T : System.Numerics.INumber<T> { throw null; }
+        public static void Max<T>(System.ReadOnlySpan<T> x, System.ReadOnlySpan<T> y, System.Span<T> destination) where T : System.Numerics.INumber<T> { }
+        public static T MinMagnitude<T>(System.ReadOnlySpan<T> x) where T : System.Numerics.INumberBase<T> { throw null; }
+        public static void MinMagnitude<T>(System.ReadOnlySpan<T> x, System.ReadOnlySpan<T> y, System.Span<T> destination) where T : System.Numerics.INumberBase<T> { }
+        public static T Min<T>(System.ReadOnlySpan<T> x) where T : System.Numerics.INumber<T> { throw null; }
+        public static void Min<T>(System.ReadOnlySpan<T> x, System.ReadOnlySpan<T> y, System.Span<T> destination) where T : System.Numerics.INumber<T> { }
+        public static void MultiplyAdd<T>(System.ReadOnlySpan<T> x, System.ReadOnlySpan<T> y, System.ReadOnlySpan<T> addend, System.Span<T> destination) where T : System.Numerics.IAdditionOperators<T, T, T>, System.Numerics.IMultiplyOperators<T, T, T> { }
+        public static void MultiplyAdd<T>(System.ReadOnlySpan<T> x, System.ReadOnlySpan<T> y, T addend, System.Span<T> destination) where T : System.Numerics.IAdditionOperators<T, T, T>, System.Numerics.IMultiplyOperators<T, T, T> { }
+        public static void MultiplyAdd<T>(System.ReadOnlySpan<T> x, T y, System.ReadOnlySpan<T> addend, System.Span<T> destination) where T : System.Numerics.IAdditionOperators<T, T, T>, System.Numerics.IMultiplyOperators<T, T, T> { }
+        public static void Multiply<T>(System.ReadOnlySpan<T> x, System.ReadOnlySpan<T> y, System.Span<T> destination) where T : System.Numerics.IMultiplyOperators<T, T, T>, System.Numerics.IMultiplicativeIdentity<T, T> { }
+        public static void Multiply<T>(System.ReadOnlySpan<T> x, T y, System.Span<T> destination) where T : System.Numerics.IMultiplyOperators<T, T, T>, System.Numerics.IMultiplicativeIdentity<T, T> { }
+        public static void Negate<T>(System.ReadOnlySpan<T> x, System.Span<T> destination) where T : System.Numerics.IUnaryNegationOperators<T, T> { }
+        public static T Norm<T>(System.ReadOnlySpan<T> x) where T : System.Numerics.IRootFunctions<T> { throw null; }
+        public static T ProductOfDifferences<T>(System.ReadOnlySpan<T> x, System.ReadOnlySpan<T> y) where T : System.Numerics.ISubtractionOperators<T, T, T>, System.Numerics.IMultiplyOperators<T, T, T>, System.Numerics.IMultiplicativeIdentity<T, T> { throw null; }
+        public static T ProductOfSums<T>(System.ReadOnlySpan<T> x, System.ReadOnlySpan<T> y) where T : System.Numerics.IAdditionOperators<T, T, T>, System.Numerics.IAdditiveIdentity<T, T>, System.Numerics.IMultiplyOperators<T, T, T>, System.Numerics.IMultiplicativeIdentity<T, T> { throw null; }
+        public static T Product<T>(System.ReadOnlySpan<T> x) where T : System.Numerics.IMultiplyOperators<T, T, T>, System.Numerics.IMultiplicativeIdentity<T, T> { throw null; }
+        public static void Sigmoid<T>(System.ReadOnlySpan<T> x, System.Span<T> destination) where T : System.Numerics.IExponentialFunctions<T> { }
+        public static void Sinh<T>(System.ReadOnlySpan<T> x, System.Span<T> destination) where T : System.Numerics.IHyperbolicFunctions<T> { }
+        public static void SoftMax<T>(System.ReadOnlySpan<T> x, System.Span<T> destination) where T : System.Numerics.IExponentialFunctions<T> { }
+        public static void Subtract<T>(System.ReadOnlySpan<T> x, System.ReadOnlySpan<T> y, System.Span<T> destination) where T : System.Numerics.ISubtractionOperators<T, T, T> { }
+        public static void Subtract<T>(System.ReadOnlySpan<T> x, T y, System.Span<T> destination) where T : System.Numerics.ISubtractionOperators<T, T, T> { }
+        public static T SumOfMagnitudes<T>(System.ReadOnlySpan<T> x) where T : System.Numerics.INumberBase<T> { throw null; }
+        public static T SumOfSquares<T>(System.ReadOnlySpan<T> x) where T : System.Numerics.IAdditionOperators<T, T, T>, System.Numerics.IAdditiveIdentity<T, T>, System.Numerics.IMultiplyOperators<T, T, T> { throw null; }
+        public static T Sum<T>(System.ReadOnlySpan<T> x) where T : System.Numerics.IAdditionOperators<T, T, T>, System.Numerics.IAdditiveIdentity<T, T> { throw null; }
+        public static void Tanh<T>(System.ReadOnlySpan<T> x, System.Span<T> destination) where T : System.Numerics.IHyperbolicFunctions<T> { }
     }
 }
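(Usage note, not part of the patch: the additions above generalize the previously float-only surface to any `T` that implements the relevant generic-math interfaces. A minimal sketch of calling the new overloads, assuming a .NET 8 TFM and this version of the package:)

```csharp
using System;
using System.Numerics.Tensors;

class TensorPrimitivesExample
{
    static void Main()
    {
        // double satisfies IAdditionOperators/IAdditiveIdentity/IRootFunctions etc.,
        // so the generic overloads above bind with no float-specific code.
        ReadOnlySpan<double> x = stackalloc double[] { 1.0, 2.0, 3.0 };
        ReadOnlySpan<double> y = stackalloc double[] { 4.0, 5.0, 6.0 };
        Span<double> destination = stackalloc double[3];

        TensorPrimitives.Add(x, y, destination);          // element-wise: {5, 7, 9}
        double dot = TensorPrimitives.Dot(x, y);          // 1*4 + 2*5 + 3*6 = 32
        double cos = TensorPrimitives.CosineSimilarity(x, y);

        Console.WriteLine($"dot={dot}, cosine={cos}");
    }
}
```

(The same calls should bind for `Half`, `float`, or a user-defined numeric type, so long as the type satisfies the `where T : ...` constraints shown above.)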
diff --git a/src/libraries/System.Numerics.Tensors/src/Resources/Strings.resx b/src/libraries/System.Numerics.Tensors/src/Resources/Strings.resx
index 86b9f4d82b1f61..def219f3544820 100644
--- a/src/libraries/System.Numerics.Tensors/src/Resources/Strings.resx
+++ b/src/libraries/System.Numerics.Tensors/src/Resources/Strings.resx
@@ -129,4 +129,7 @@
   <data name="Argument_InputAndDestinationSpanMustNotOverlap" xml:space="preserve">
     <value>The destination span may only overlap with an input span if the two spans start at the same memory location.</value>
   </data>
+  <data name="Overflow_NegateTwosCompNum" xml:space="preserve">
+    <value>Negating the minimum value of a twos complement number is invalid.</value>
+  </data>
 </root>
diff --git a/src/libraries/System.Numerics.Tensors/src/System.Numerics.Tensors.csproj b/src/libraries/System.Numerics.Tensors/src/System.Numerics.Tensors.csproj
index cb1e66fdea1c06..ae18ca1b2d9398 100644
--- a/src/libraries/System.Numerics.Tensors/src/System.Numerics.Tensors.csproj
+++ b/src/libraries/System.Numerics.Tensors/src/System.Numerics.Tensors.csproj
@@ -9,16 +9,19 @@
-
+
+
-
+
+
+
-
+
diff --git a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.Helpers.cs b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.Helpers.cs
new file mode 100644
index 00000000000000..4b5f40cec39d0f
--- /dev/null
+++ b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.Helpers.cs
@@ -0,0 +1,384 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System.Diagnostics.CodeAnalysis;
+using System.Diagnostics;
+using System.Runtime.CompilerServices;
+using System.Runtime.InteropServices;
+
+namespace System.Numerics.Tensors
+{
+    /// <summary>Performs primitive tensor operations over spans of memory.</summary>
+    public static partial class TensorPrimitives
+    {
+        /// <summary>Throws an exception if the <paramref name="input"/> and <paramref name="output"/> spans overlap and don't begin at the same memory location.</summary>
+        [MethodImpl(MethodImplOptions.AggressiveInlining)]
+        private static void ValidateInputOutputSpanNonOverlapping<T>(ReadOnlySpan<T> input, Span<T> output)
+        {
+            if (!Unsafe.AreSame(ref MemoryMarshal.GetReference(input), ref MemoryMarshal.GetReference(output)) &&
+                input.Overlaps(output))
+            {
+                ThrowHelper.ThrowArgument_InputAndDestinationSpanMustNotOverlap();
+            }
+        }
+
+        /// <summary>Throws an <see cref="OverflowException"/> for trying to negate the minimum value of a two's-complement value.</summary>
+        internal static void ThrowNegateTwosCompOverflow() => throw new OverflowException(SR.Overflow_NegateTwosCompNum);
+
+        /// <summary>Mask used to handle alignment elements before vectorized handling of the input.</summary>
+        ///
+        /// Logically 64 rows of 64 bytes. The Nth row should be used to handle N alignment elements at the
+        /// beginning of the input, where elements in the vector after that will be zero'd.
+        ///
+        /// There actually exists 65 rows in the table with the last row being a repeat of the first. This is
+        /// done because it allows the main algorithms to use a simplified algorithm when computing the amount
+        /// of misalignment where we always skip the first 64 elements, even if already aligned, so we don't
+        /// double process them. This allows us to avoid an additional branch.
+        ///
+        private static ReadOnlySpan<byte> AlignmentByteMask_64x65 =>
+        [
+            0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
+            0xFF, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+            0xFF, 0xFF, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+            0xFF, 0xFF, 0xFF, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+            0xFF, 0xFF, 0xFF, 0xFF, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+            0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+            0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+            0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+            0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+            0xFF, 0xFF,
0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 
0xFF, 0xFF, 0xFF, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x00, 
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x00, 0x00, 0x00, 0x00, 0x00, 
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0xFF, 0xFF, 0xFF, 
0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x00, 0x00, 0x00, 0x00, 0x00, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x00, 0x00, 0x00, 0x00, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x00, 0x00, 0x00, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x00, 0x00, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x00, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 
0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
+        ];
+
+        /// <summary>Mask used to handle alignment elements before vectorized handling of the input.</summary>
+        ///
+        /// Logically 32 rows of 32 ushorts. The Nth row should be used to handle N alignment elements at the
+        /// beginning of the input, where elements in the vector after that will be zero'd.
+        ///
+        /// There actually exists 33 rows in the table with the last row being a repeat of the first. This is
+        /// done because it allows the main algorithms to use a simplified algorithm when computing the amount
+        /// of misalignment where we always skip the first 32 elements, even if already aligned, so we don't
+        /// double process them. This allows us to avoid an additional branch.
+        ///
+        private static ReadOnlySpan<ushort> AlignmentUInt16Mask_32x33 =>
+        [
+            0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF,
+            0xFFFF, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000,
+            0xFFFF, 0xFFFF, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000,
+            0xFFFF, 0xFFFF, 0xFFFF, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000,
+            0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000,
+            0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000,
+            0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000,
+            0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000,
+            0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000,
+            0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000,
0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, + 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, + 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, + 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, + 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, + 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, + 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, + 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, + 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, + 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, + 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, + 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, + 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, + 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, + 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 
0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, + 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, + 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, + 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, + 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, + 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0x0000, 0x0000, 0x0000, 0x0000, + 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0x0000, 0x0000, 0x0000, + 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0x0000, 0x0000, + 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0x0000, + 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, + ]; + + /// Mask used to handle alignment elements before vectorized handling of the input. + /// + /// Logically 16 rows of 16 uints. The Nth row should be used to handle N alignment elements at the + /// beginning of the input, where elements in the vector after that will be zero'd. + /// + /// There actually exists 17 rows in the table with the last row being a repeat of the first. This is + /// done because it allows the main algorithms to use a simplified algorithm when computing the amount + /// of misalignment where we always skip the first 16 elements, even if already aligned, so we don't + /// double process them. This allows us to avoid an additional branch. 
+        ///
+        private static ReadOnlySpan<uint> AlignmentUInt32Mask_16x17 =>
+        [
+            0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF,
+            0xFFFFFFFF, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
+            0xFFFFFFFF, 0xFFFFFFFF, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
+            0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
+            0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
+            0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
+            0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
+            0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
+            0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
+            0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
+            0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
+            0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
+            0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
+            0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0x00000000, 0x00000000, 0x00000000,
+            0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0x00000000, 0x00000000,
+            0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0x00000000,
+            0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF,
+        ];
+
+        /// <summary>Mask used to handle alignment elements before vectorized handling of the input.</summary>
+        ///
+        /// Logically 8 rows of 8 ulongs. The Nth row should be used to handle N alignment elements at the
+        /// beginning of the input, where elements in the vector after that will be zero'd.
+        ///
+        /// There actually exists 9 rows in the table with the last row being a repeat of the first. This is
+        /// done because it allows the main algorithms to use a simplified algorithm when computing the amount
+        /// of misalignment where we always skip the first 8 elements, even if already aligned, so we don't
+        /// double process them. This allows us to avoid an additional branch.
+        ///
+        private static ReadOnlySpan<ulong> AlignmentUInt64Mask_8x9 =>
+        [
+            0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF,
+            0xFFFFFFFFFFFFFFFF, 0x0000000000000000, 0x0000000000000000, 0x0000000000000000, 0x0000000000000000, 0x0000000000000000, 0x0000000000000000, 0x0000000000000000,
+            0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0x0000000000000000, 0x0000000000000000, 0x0000000000000000, 0x0000000000000000, 0x0000000000000000, 0x0000000000000000,
+            0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0x0000000000000000, 0x0000000000000000, 0x0000000000000000, 0x0000000000000000, 0x0000000000000000,
+            0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0x0000000000000000, 0x0000000000000000, 0x0000000000000000, 0x0000000000000000,
+            0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0x0000000000000000, 0x0000000000000000, 0x0000000000000000,
+            0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0x0000000000000000, 0x0000000000000000,
+            0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0x0000000000000000,
+            0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF,
+        ];
+
+        /// <summary>Mask used to handle remaining elements after vectorized handling of the input.</summary>
+        ///
+        /// Logically 64 rows of 64 bytes. The Nth row should be used to handle N remaining elements at the
+        /// end of the input, where elements in the vector prior to that will be zero'd.
+        ///
+        /// Much as with the AlignmentMask table, we actually have 65 rows where the last row is a repeat of
+        /// the first. Doing this allows us to avoid an additional branch and instead to always process the
+        /// last 16 elements via a conditional select instead.
+        ///
+        private static ReadOnlySpan<byte> RemainderByteMask_64x65 =>
+        [
+            0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+            0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xFF,
+            0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xFF, 0xFF,
+            0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xFF, 0xFF, 0xFF,
+            0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xFF, 0xFF, 0xFF, 0xFF,
+            0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
+            0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
+            0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
+            0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
+            0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 
0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 
0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0x00, 0x00, 0x00, 
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0x00, 0x00, 0x00, 0x00, 0x00, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0x00, 0x00, 0x00, 0x00, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0x00, 0x00, 0x00, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0x00, 0x00, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0x00, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + ]; + + /// Mask used to handle remaining elements after vectorized handling of the input. + /// + /// Logically 32 rows of 32 ushorts. The Nth row should be used to handle N remaining elements at the + /// end of the input, where elements in the vector prior to that will be zero'd. + /// + /// Much as with the AlignmentMask table, we actually have 33 rows where the last row is a repeat of + /// the first. Doing this allows us to avoid an additional branch and instead to always process the + /// last 16 elements via a conditional select instead. + /// + private static ReadOnlySpan RemainderUInt16Mask_32x33 => + [ + 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, + 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0xFFFF, + 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0xFFFF, 0xFFFF, + 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0xFFFF, 0xFFFF, 0xFFFF, + 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, + 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, + 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, + 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, + 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, + 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, + 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 
0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, + 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, + 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, + 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, + 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, + 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, + 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, + 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, + 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, + 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, + 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, + 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, + 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, + 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 
+ 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, + 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, + 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, + 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, + 0x0000, 0x0000, 0x0000, 0x0000, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, + 0x0000, 0x0000, 0x0000, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, + 0x0000, 0x0000, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, + 0x0000, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, + 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, + ]; + + /// Mask used to handle remaining elements after vectorized handling of the input. + /// + /// Logically 16 rows of 16 uints. The Nth row should be used to handle N remaining elements at the + /// end of the input, where elements in the vector prior to that will be zero'd. + /// + /// Much as with the AlignmentMask table, we actually have 17 rows where the last row is a repeat of + /// the first. Doing this allows us to avoid an additional branch and instead to always process the + /// last 16 elements via a conditional select instead. 
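To make the conditional-select tail handling concrete: a minimal sketch, assuming .NET 8's Vector256 APIs, of how the final partial vector can be folded without a scalar loop. The helper below is hypothetical and computes the lane mask inline; the tables in this diff instead precompute each mask row so the tail path is a single indexed load, and the duplicated all-zero row is what the comments mean by trading a branch for an unconditional select.

using System;
using System.Runtime.Intrinsics;

static class RemainderSketch
{
    // Sums a float span; the tail that doesn't fill a vector is handled by
    // reloading the last full vector and blending away already-summed lanes.
    static float Sum(ReadOnlySpan<float> x)
    {
        int vw = Vector256<float>.Count;
        if (x.Length < vw)
        {
            float s = 0f;
            foreach (float f in x) s += f;
            return s;
        }

        Vector256<float> acc = Vector256<float>.Zero;
        int i = 0;
        for (; i <= x.Length - vw; i += vw)
        {
            acc += Vector256.Create(x.Slice(i, vw));
        }

        int remaining = x.Length - i; // 0..vw-1 elements not yet summed
        if (remaining != 0)
        {
            // Keep only the last 'remaining' lanes, mirroring the mask rows above.
            Vector256<float> last = Vector256.Create(x.Slice(x.Length - vw, vw));
            Vector256<int> lane = Vector256.Create(0, 1, 2, 3, 4, 5, 6, 7);
            Vector256<float> mask = Vector256.GreaterThanOrEqual(lane, Vector256.Create(vw - remaining)).AsSingle();
            acc += Vector256.ConditionalSelect(mask, last, Vector256<float>.Zero);
        }

        return Vector256.Sum(acc);
    }
}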
+        ///
+        private static ReadOnlySpan<uint> RemainderUInt32Mask_16x17 =>
+        [
+            0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
+            0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0xFFFFFFFF,
+            0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF,
+            0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF,
+            0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF,
+            0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF,
+            0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF,
+            0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF,
+            0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF,
+            0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF,
+            0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF,
+            0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF,
+            0x00000000, 0x00000000, 0x00000000, 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF,
+            0x00000000, 0x00000000, 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF,
+            0x00000000, 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF,
+            0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF,
+            0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
+        ];
+
+        /// Mask used to handle remaining elements after vectorized handling of the input.
+        ///
+        /// Logically 8 rows of 8 ulongs.
The Nth row should be used to handle N remaining elements at the + /// end of the input, where elements in the vector prior to that will be zero'd. + /// + /// Much as with the AlignmentMask table, we actually have 9 rows where the last row is a repeat of + /// the first. Doing this allows us to avoid an additional branch and instead to always process the + /// last 8 elements via a conditional select instead. + /// + private static ReadOnlySpan RemainderUInt64Mask_8x9 => + [ + 0x0000000000000000, 0x0000000000000000, 0x0000000000000000, 0x0000000000000000, 0x0000000000000000, 0x0000000000000000, 0x0000000000000000, 0x0000000000000000, + 0x0000000000000000, 0x0000000000000000, 0x0000000000000000, 0x0000000000000000, 0x0000000000000000, 0x0000000000000000, 0x0000000000000000, 0xFFFFFFFFFFFFFFFF, + 0x0000000000000000, 0x0000000000000000, 0x0000000000000000, 0x0000000000000000, 0x0000000000000000, 0x0000000000000000, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, + 0x0000000000000000, 0x0000000000000000, 0x0000000000000000, 0x0000000000000000, 0x0000000000000000, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, + 0x0000000000000000, 0x0000000000000000, 0x0000000000000000, 0x0000000000000000, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, + 0x0000000000000000, 0x0000000000000000, 0x0000000000000000, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, + 0x0000000000000000, 0x0000000000000000, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, + 0x0000000000000000, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, + 0x0000000000000000, 0x0000000000000000, 0x0000000000000000, 0x0000000000000000, 0x0000000000000000, 0x0000000000000000, 0x0000000000000000, 0x0000000000000000, + ]; + } +} diff --git a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.cs b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.Single.cs similarity index 84% rename from src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.cs rename to src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.Single.cs index 88094042e62c9f..0ad05d15286d97 100644 --- a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.cs +++ b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.Single.cs @@ -27,7 +27,7 @@ public static partial class TensorPrimitives /// /// public static void Abs(ReadOnlySpan x, Span destination) => - InvokeSpanIntoSpan(x, destination); + InvokeSpanIntoSpan(x, destination); /// Computes the element-wise addition of single-precision floating-point numbers in the specified tensors. /// The first tensor, represented as a span. @@ -46,7 +46,7 @@ public static void Abs(ReadOnlySpan x, Span destination) => /// /// public static void Add(ReadOnlySpan x, ReadOnlySpan y, Span destination) => - InvokeSpanSpanIntoSpan(x, y, destination); + InvokeSpanSpanIntoSpan(x, y, destination); /// Computes the element-wise addition of single-precision floating-point numbers in the specified tensors. /// The first tensor, represented as a span. 
@@ -63,7 +63,7 @@ public static void Add(ReadOnlySpan x, ReadOnlySpan y, Span /// /// public static void Add(ReadOnlySpan x, float y, Span destination) => - InvokeSpanScalarIntoSpan(x, y, destination); + InvokeSpanScalarIntoSpan(x, y, destination); /// Computes the element-wise result of ( + ) * for the specified tensors. /// The first tensor, represented as a span. @@ -84,7 +84,7 @@ public static void Add(ReadOnlySpan x, float y, Span destination) /// /// public static void AddMultiply(ReadOnlySpan x, ReadOnlySpan y, ReadOnlySpan multiplier, Span destination) => - InvokeSpanSpanSpanIntoSpan(x, y, multiplier, destination); + InvokeSpanSpanSpanIntoSpan(x, y, multiplier, destination); /// Computes the element-wise result of ( + ) * for the specified tensors. /// The first tensor, represented as a span. @@ -104,7 +104,7 @@ public static void AddMultiply(ReadOnlySpan x, ReadOnlySpan y, Rea /// /// public static void AddMultiply(ReadOnlySpan x, ReadOnlySpan y, float multiplier, Span destination) => - InvokeSpanSpanScalarIntoSpan(x, y, multiplier, destination); + InvokeSpanSpanScalarIntoSpan(x, y, multiplier, destination); /// Computes the element-wise result of ( + ) * for the specified tensors. /// The first tensor, represented as a span. @@ -124,7 +124,7 @@ public static void AddMultiply(ReadOnlySpan x, ReadOnlySpan y, flo /// /// public static void AddMultiply(ReadOnlySpan x, float y, ReadOnlySpan multiplier, Span destination) => - InvokeSpanScalarSpanIntoSpan(x, y, multiplier, destination); + InvokeSpanScalarSpanIntoSpan(x, y, multiplier, destination); /// Computes the element-wise hyperbolic cosine of each single-precision floating-point radian angle in the specified tensor. /// The tensor, represented as a span. @@ -148,7 +148,7 @@ public static void AddMultiply(ReadOnlySpan x, float y, ReadOnlySpan /// public static void Cosh(ReadOnlySpan x, Span destination) => - InvokeSpanIntoSpan(x, destination); + InvokeSpanIntoSpan(x, destination); /// Computes the cosine similarity between the two specified non-empty, equal-length tensors of single-precision floating-point numbers. /// The first tensor, represented as a span. @@ -203,7 +203,7 @@ public static float Distance(ReadOnlySpan x, ReadOnlySpan y) ThrowHelper.ThrowArgument_SpansMustBeNonEmpty(); } - return MathF.Sqrt(Aggregate(x, y)); + return MathF.Sqrt(Aggregate(x, y)); } /// Computes the element-wise division of single-precision floating-point numbers in the specified tensors. @@ -223,7 +223,7 @@ public static float Distance(ReadOnlySpan x, ReadOnlySpan y) /// /// public static void Divide(ReadOnlySpan x, ReadOnlySpan y, Span destination) => - InvokeSpanSpanIntoSpan(x, y, destination); + InvokeSpanSpanIntoSpan(x, y, destination); /// Computes the element-wise division of single-precision floating-point numbers in the specified tensors. /// The first tensor, represented as a span. @@ -240,7 +240,7 @@ public static void Divide(ReadOnlySpan x, ReadOnlySpan y, Span /// public static void Divide(ReadOnlySpan x, float y, Span destination) => - InvokeSpanScalarIntoSpan(x, y, destination); + InvokeSpanScalarIntoSpan(x, y, destination); /// Computes the dot product of two tensors containing single-precision floating-point numbers. /// The first tensor, represented as a span. 
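The Aggregate calls being retargeted throughout these hunks all share one contract: a per-element transform fused into a running reduction. A scalar sketch of that contract, with hypothetical Func parameters; the real code passes struct operator types as generic parameters instead, so the JIT can inline and vectorize them:

using System;

static class AggregateSketch
{
    // Folds transform(x[i], y[i]) into an accumulator; Dot is '+' over 'a*b',
    // Distance is Sqrt of '+' over '(a-b)^2', ProductOfSums is '*' over 'a+b'.
    static float Aggregate(ReadOnlySpan<float> x, ReadOnlySpan<float> y,
        Func<float, float, float> transform, Func<float, float, float> reduce, float seed)
    {
        float acc = seed;
        for (int i = 0; i < x.Length; i++)
        {
            acc = reduce(acc, transform(x[i], y[i]));
        }
        return acc;
    }

    static float Dot(ReadOnlySpan<float> x, ReadOnlySpan<float> y) =>
        Aggregate(x, y, static (a, b) => a * b, static (a, b) => a + b, 0f);

    static float Distance(ReadOnlySpan<float> x, ReadOnlySpan<float> y) =>
        MathF.Sqrt(Aggregate(x, y, static (a, b) => (a - b) * (a - b), static (a, b) => a + b, 0f));
}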
@@ -266,7 +266,7 @@ public static void Divide(ReadOnlySpan x, float y, Span destinatio /// /// public static float Dot(ReadOnlySpan x, ReadOnlySpan y) => - Aggregate(x, y); + Aggregate(x, y); /// Computes the element-wise result of raising e to the single-precision floating-point number powers in the specified tensor. /// The tensor, represented as a span. @@ -287,7 +287,7 @@ public static float Dot(ReadOnlySpan x, ReadOnlySpan y) => /// /// public static void Exp(ReadOnlySpan x, Span destination) => - InvokeSpanIntoSpan(x, destination); + InvokeSpanIntoSpan(x, destination); /// Searches for the index of the largest single-precision floating-point number in the specified tensor. /// The tensor, represented as a span. @@ -303,7 +303,7 @@ public static void Exp(ReadOnlySpan x, Span destination) => /// /// public static int IndexOfMax(ReadOnlySpan x) => - IndexOfMinMaxCore(x); + IndexOfMinMaxCore(x); /// Searches for the index of the single-precision floating-point number with the largest magnitude in the specified tensor. /// The tensor, represented as a span. @@ -320,7 +320,7 @@ public static int IndexOfMax(ReadOnlySpan x) => /// /// public static int IndexOfMaxMagnitude(ReadOnlySpan x) => - IndexOfMinMaxCore(x); + IndexOfMinMaxCore(x); /// Searches for the index of the smallest single-precision floating-point number in the specified tensor. /// The tensor, represented as a span. @@ -336,7 +336,7 @@ public static int IndexOfMaxMagnitude(ReadOnlySpan x) => /// /// public static int IndexOfMin(ReadOnlySpan x) => - IndexOfMinMaxCore(x); + IndexOfMinMaxCore(x); /// Searches for the index of the single-precision floating-point number with the smallest magnitude in the specified tensor. /// The tensor, represented as a span. @@ -353,7 +353,7 @@ public static int IndexOfMin(ReadOnlySpan x) => /// /// public static int IndexOfMinMagnitude(ReadOnlySpan x) => - IndexOfMinMaxCore(x); + IndexOfMinMaxCore(x); /// Computes the element-wise natural (base e) logarithm of single-precision floating-point numbers in the specified tensor. /// The tensor, represented as a span. @@ -376,7 +376,7 @@ public static int IndexOfMinMagnitude(ReadOnlySpan x) => /// /// public static void Log(ReadOnlySpan x, Span destination) => - InvokeSpanIntoSpan(x, destination); + InvokeSpanIntoSpan(x, destination); /// Computes the element-wise base 2 logarithm of single-precision floating-point numbers in the specified tensor. /// The tensor, represented as a span. @@ -399,7 +399,7 @@ public static void Log(ReadOnlySpan x, Span destination) => /// /// public static void Log2(ReadOnlySpan x, Span destination) => - InvokeSpanIntoSpan(x, destination); + InvokeSpanIntoSpan(x, destination); /// Searches for the largest single-precision floating-point number in the specified tensor. /// The tensor, represented as a span. @@ -416,7 +416,7 @@ public static void Log2(ReadOnlySpan x, Span destination) => /// /// public static float Max(ReadOnlySpan x) => - MinMaxCore(x); + MinMaxCore(x); /// Computes the element-wise maximum of the single-precision floating-point numbers in the specified tensors. /// The first tensor, represented as a span. @@ -440,7 +440,7 @@ public static float Max(ReadOnlySpan x) => /// /// public static void Max(ReadOnlySpan x, ReadOnlySpan y, Span destination) => - InvokeSpanSpanIntoSpan(x, y, destination); + InvokeSpanSpanIntoSpan(x, y, destination); /// Searches for the single-precision floating-point number with the largest magnitude in the specified tensor. /// The tensor, represented as a span. 
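The magnitude-based reductions being rewired here pick the element with the largest (or smallest) absolute value but return it with its original sign. A small usage sketch, assuming the System.Numerics.Tensors package:

using System;
using System.Numerics.Tensors;

class MagnitudeExample
{
    static void Main()
    {
        float[] values = { -5f, 3f, 4f };
        Console.WriteLine(TensorPrimitives.MaxMagnitude(values)); // -5 (|-5| is largest)
        Console.WriteLine(TensorPrimitives.Max(values));          // 4
    }
}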
@@ -458,7 +458,7 @@ public static void Max(ReadOnlySpan x, ReadOnlySpan y, Span /// /// public static float MaxMagnitude(ReadOnlySpan x) => - MinMaxCore(x); + MinMaxCore(x); /// Computes the element-wise single-precision floating-point number with the largest magnitude in the specified tensors. /// The first tensor, represented as a span. @@ -476,7 +476,7 @@ public static float MaxMagnitude(ReadOnlySpan x) => /// /// public static void MaxMagnitude(ReadOnlySpan x, ReadOnlySpan y, Span destination) => - InvokeSpanSpanIntoSpan(x, y, destination); + InvokeSpanSpanIntoSpan(x, y, destination); /// Searches for the smallest single-precision floating-point number in the specified tensor. /// The tensor, represented as a span. @@ -493,7 +493,7 @@ public static void MaxMagnitude(ReadOnlySpan x, ReadOnlySpan y, Sp /// /// public static float Min(ReadOnlySpan x) => - MinMaxCore(x); + MinMaxCore(x); /// Computes the element-wise minimum of the single-precision floating-point numbers in the specified tensors. /// The first tensor, represented as a span. @@ -517,7 +517,7 @@ public static float Min(ReadOnlySpan x) => /// /// public static void Min(ReadOnlySpan x, ReadOnlySpan y, Span destination) => - InvokeSpanSpanIntoSpan(x, y, destination); + InvokeSpanSpanIntoSpan(x, y, destination); /// Searches for the single-precision floating-point number with the smallest magnitude in the specified tensor. /// The tensor, represented as a span. @@ -535,7 +535,7 @@ public static void Min(ReadOnlySpan x, ReadOnlySpan y, Span /// /// public static float MinMagnitude(ReadOnlySpan x) => - MinMaxCore(x); + MinMaxCore(x); /// Computes the element-wise single-precision floating-point number with the smallest magnitude in the specified tensors. /// The first tensor, represented as a span. @@ -558,7 +558,7 @@ public static float MinMagnitude(ReadOnlySpan x) => /// /// public static void MinMagnitude(ReadOnlySpan x, ReadOnlySpan y, Span destination) => - InvokeSpanSpanIntoSpan(x, y, destination); + InvokeSpanSpanIntoSpan(x, y, destination); /// Computes the element-wise product of single-precision floating-point numbers in the specified tensors. /// The first tensor, represented as a span. @@ -577,7 +577,7 @@ public static void MinMagnitude(ReadOnlySpan x, ReadOnlySpan y, Sp /// /// public static void Multiply(ReadOnlySpan x, ReadOnlySpan y, Span destination) => - InvokeSpanSpanIntoSpan(x, y, destination); + InvokeSpanSpanIntoSpan(x, y, destination); /// Computes the element-wise product of single-precision floating-point numbers in the specified tensors. /// The first tensor, represented as a span. @@ -595,7 +595,7 @@ public static void Multiply(ReadOnlySpan x, ReadOnlySpan y, Span /// public static void Multiply(ReadOnlySpan x, float y, Span destination) => - InvokeSpanScalarIntoSpan(x, y, destination); + InvokeSpanScalarIntoSpan(x, y, destination); /// Computes the element-wise result of ( * ) * for the specified tensors of single-precision floating-point numbers. /// The first tensor, represented as a span. @@ -616,7 +616,7 @@ public static void Multiply(ReadOnlySpan x, float y, Span destinat /// /// public static void MultiplyAdd(ReadOnlySpan x, ReadOnlySpan y, ReadOnlySpan addend, Span destination) => - InvokeSpanSpanSpanIntoSpan(x, y, addend, destination); + InvokeSpanSpanSpanIntoSpan(x, y, addend, destination); /// Computes the element-wise result of ( * ) * for the specified tensors of single-precision floating-point numbers. /// The first tensor, represented as a span. 
@@ -637,7 +637,7 @@ public static void MultiplyAdd(ReadOnlySpan x, ReadOnlySpan y, Rea /// /// public static void MultiplyAdd(ReadOnlySpan x, ReadOnlySpan y, float addend, Span destination) => - InvokeSpanSpanScalarIntoSpan(x, y, addend, destination); + InvokeSpanSpanScalarIntoSpan(x, y, addend, destination); /// Computes the element-wise result of ( * ) * for the specified tensors of single-precision floating-point numbers. /// The first tensor, represented as a span. @@ -657,7 +657,7 @@ public static void MultiplyAdd(ReadOnlySpan x, ReadOnlySpan y, flo /// /// public static void MultiplyAdd(ReadOnlySpan x, float y, ReadOnlySpan addend, Span destination) => - InvokeSpanScalarSpanIntoSpan(x, y, addend, destination); + InvokeSpanScalarSpanIntoSpan(x, y, addend, destination); /// Computes the element-wise negation of each single-precision floating-point number in the specified tensor. /// The tensor, represented as a span. @@ -673,7 +673,7 @@ public static void MultiplyAdd(ReadOnlySpan x, float y, ReadOnlySpan /// public static void Negate(ReadOnlySpan x, Span destination) => - InvokeSpanIntoSpan(x, destination); + InvokeSpanIntoSpan(x, destination); /// Computes the Euclidean norm of the specified tensor of single-precision floating-point numbers. /// The first tensor, represented as a span. @@ -715,7 +715,7 @@ public static float Product(ReadOnlySpan x) ThrowHelper.ThrowArgument_SpansMustBeNonEmpty(); } - return Aggregate(x); + return Aggregate(x); } /// Computes the product of the element-wise differences of the single-precision floating-point numbers in the specified non-empty tensors. @@ -746,7 +746,7 @@ public static float ProductOfDifferences(ReadOnlySpan x, ReadOnlySpan(x, y); + return Aggregate(x, y); } /// Computes the product of the element-wise sums of the single-precision floating-point numbers in the specified non-empty tensors. @@ -777,7 +777,7 @@ public static float ProductOfSums(ReadOnlySpan x, ReadOnlySpan y) ThrowHelper.ThrowArgument_SpansMustBeNonEmpty(); } - return Aggregate(x, y); + return Aggregate(x, y); } /// Computes the element-wise sigmoid function on the specified non-empty tensor of single-precision floating-point numbers. @@ -802,7 +802,7 @@ public static void Sigmoid(ReadOnlySpan x, Span destination) ThrowHelper.ThrowArgument_SpansMustBeNonEmpty(); } - InvokeSpanIntoSpan(x, destination); + InvokeSpanIntoSpan(x, destination); } /// Computes the element-wise hyperbolic sine of each single-precision floating-point radian angle in the specified tensor. @@ -827,7 +827,7 @@ public static void Sigmoid(ReadOnlySpan x, Span destination) /// /// public static void Sinh(ReadOnlySpan x, Span destination) => - InvokeSpanIntoSpan(x, destination); + InvokeSpanIntoSpan(x, destination); /// Computes the softmax function over the specified non-empty tensor of single-precision floating-point numbers. /// The tensor, represented as a span. @@ -859,9 +859,9 @@ public static void SoftMax(ReadOnlySpan x, Span destination) ValidateInputOutputSpanNonOverlapping(x, destination); - float expSum = Aggregate(x); + float expSum = Aggregate(x); - InvokeSpanScalarIntoSpan(x, expSum, destination); + InvokeSpanScalarIntoSpan(x, expSum, destination); } /// Computes the element-wise difference between single-precision floating-point numbers in the specified tensors. 
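As read from the hunk above, SoftMax keeps the same two-pass shape on both sides of the rename: one aggregation pass for the sum of exponentials, then an element-wise exp-and-divide. A scalar rendering, illustrative only:

using System;

static class SoftMaxSketch
{
    // destination[i] = exp(x[i]) / sum over j of exp(x[j]), as in the code above.
    static void SoftMax(ReadOnlySpan<float> x, Span<float> destination)
    {
        float expSum = 0f;
        for (int i = 0; i < x.Length; i++)
        {
            expSum += MathF.Exp(x[i]);
        }

        for (int i = 0; i < x.Length; i++)
        {
            destination[i] = MathF.Exp(x[i]) / expSum;
        }
    }
}

Note this follows the direct definition; it does not include the max-subtraction rescaling some softmax implementations apply for overflow protection.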
@@ -881,7 +881,7 @@ public static void SoftMax(ReadOnlySpan x, Span destination) /// /// public static void Subtract(ReadOnlySpan x, ReadOnlySpan y, Span destination) => - InvokeSpanSpanIntoSpan(x, y, destination); + InvokeSpanSpanIntoSpan(x, y, destination); /// Computes the element-wise difference between single-precision floating-point numbers in the specified tensors. /// The first tensor, represented as a span. @@ -898,7 +898,7 @@ public static void Subtract(ReadOnlySpan x, ReadOnlySpan y, Span /// public static void Subtract(ReadOnlySpan x, float y, Span destination) => - InvokeSpanScalarIntoSpan(x, y, destination); + InvokeSpanScalarIntoSpan(x, y, destination); /// Computes the sum of all elements in the specified tensor of single-precision floating-point numbers. /// The tensor, represented as a span. @@ -913,7 +913,7 @@ public static void Subtract(ReadOnlySpan x, float y, Span destinat /// /// public static float Sum(ReadOnlySpan x) => - Aggregate(x); + Aggregate(x); /// Computes the sum of the absolute values of every element in the specified tensor of single-precision floating-point numbers. /// The tensor, represented as a span. @@ -934,7 +934,7 @@ public static float Sum(ReadOnlySpan x) => /// /// public static float SumOfMagnitudes(ReadOnlySpan x) => - Aggregate(x); + Aggregate(x); /// Computes the sum of the square of every element in the specified tensor of single-precision floating-point numbers. /// The tensor, represented as a span. @@ -955,7 +955,7 @@ public static float SumOfMagnitudes(ReadOnlySpan x) => /// /// public static float SumOfSquares(ReadOnlySpan x) => - Aggregate(x); + Aggregate(x); /// Computes the element-wise hyperbolic tangent of each single-precision floating-point radian angle in the specified tensor. /// The tensor, represented as a span. @@ -980,78 +980,6 @@ public static float SumOfSquares(ReadOnlySpan x) => /// /// public static void Tanh(ReadOnlySpan x, Span destination) => - InvokeSpanIntoSpan(x, destination); - - /// Throws an exception if the and spans overlap and don't begin at the same memory location. - [MethodImpl(MethodImplOptions.AggressiveInlining)] - private static void ValidateInputOutputSpanNonOverlapping(ReadOnlySpan input, Span output) - { - if (!Unsafe.AreSame(ref MemoryMarshal.GetReference(input), ref MemoryMarshal.GetReference(output)) && - input.Overlaps(output)) - { - ThrowHelper.ThrowArgument_InputAndDestinationSpanMustNotOverlap(); - } - } - - /// Mask used to handle alignment elements before vectorized handling of the input. - /// - /// Logically 16 rows of 16 uints. The Nth row should be used to handle N alignment elements at the - /// beginning of the input, where elements in the vector after that will be zero'd. - /// - /// There actually exists 17 rows in the table with the last row being a repeat of the first. This is - /// done because it allows the main algorithms to use a simplified algorithm when computing the amount - /// of misalignment where we always skip the first 16 elements, even if already aligned, so we don't - /// double process them. This allows us to avoid an additional branch. 
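For the leading edge, the alignment table described in the deleted comment above is the mirror image of the remainder tables: row N keeps the first N lanes. A sketch of that shape under the same assumptions (mask computed inline, hypothetical helper, x.Length at least one vector):

using System;
using System.Runtime.Intrinsics;

static class AlignmentSketch
{
    // Negates only the first 'misalignment' lanes (1..8) of x's leading vector,
    // leaving the remaining lanes untouched for the aligned main loop to handle.
    static void NegateLeading(Span<float> x, int misalignment)
    {
        ReadOnlySpan<float> head = x.Slice(0, Vector256<float>.Count);
        Vector256<float> v = Vector256.Create(head);
        Vector256<int> lane = Vector256.Create(0, 1, 2, 3, 4, 5, 6, 7);
        Vector256<float> keep = Vector256.LessThan(lane, Vector256.Create(misalignment)).AsSingle();
        Vector256.ConditionalSelect(keep, -v, v).CopyTo(x);
    }
}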
- /// - private static ReadOnlySpan AlignmentUInt32Mask_16x16 => - [ - 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, - 0xFFFFFFFF, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, - 0xFFFFFFFF, 0xFFFFFFFF, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, - 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, - 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, - 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, - 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, - 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, - 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, - 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, - 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, - 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, - 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0x00000000, 0x00000000, 0x00000000, 0x00000000, - 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0x00000000, 0x00000000, 0x00000000, - 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0x00000000, 0x00000000, - 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0x00000000, - 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, - ]; - - /// Mask used to handle remaining elements after vectorized handling of the input. - /// - /// Logically 16 rows of 16 uints. 
The Nth row should be used to handle N remaining elements at the - /// end of the input, where elements in the vector prior to that will be zero'd. - /// - /// Much as with the AlignmentMask table, we actually have 17 rows where the last row is a repeat of - /// the first. Doing this allows us to avoid an additional branch and instead to always process the - /// last 16 elements via a conditional select instead. - /// - private static ReadOnlySpan RemainderUInt32Mask_16x16 => - [ - 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, - 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0xFFFFFFFF, - 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, - 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, - 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, - 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, - 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, - 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, - 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, - 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, - 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, - 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, - 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, - 0x00000000, 0x00000000, 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, - 0x00000000, 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, - 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 
0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, - 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, - ]; + InvokeSpanIntoSpan(x, destination); } } diff --git a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/TensorPrimitives.Single.netcore.cs b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/TensorPrimitives.Single.netcore.cs new file mode 100644 index 00000000000000..3747906c5317e6 --- /dev/null +++ b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/TensorPrimitives.Single.netcore.cs @@ -0,0 +1,98 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +// This file exists to enable TensorPrimitives.float.cs to be compiled for both +// netstandard2.0 and net8.0+ targets. It uses the XX_Single names and the operation +// methods tied to float, whereas the net8.0+ worker implementations use generic math. +// This file provides float-bound types and type defs that route one to the other. + +global using AbsoluteOperator_Single = System.Numerics.Tensors.TensorPrimitives.AbsoluteOperator; +global using AddOperator_Single = System.Numerics.Tensors.TensorPrimitives.AddOperator; +global using AddMultiplyOperator_Single = System.Numerics.Tensors.TensorPrimitives.AddMultiplyOperator; +global using CoshOperator_Single = System.Numerics.Tensors.TensorPrimitives.CoshOperator; +global using SubtractSquaredOperator_Single = System.Numerics.Tensors.TensorPrimitives.SubtractSquaredOperator; +global using DivideOperator_Single = System.Numerics.Tensors.TensorPrimitives.DivideOperator; +global using MultiplyOperator_Single = System.Numerics.Tensors.TensorPrimitives.MultiplyOperator; +global using ExpOperator_Single = System.Numerics.Tensors.TensorPrimitives.ExpOperator; +global using LogOperator_Single = System.Numerics.Tensors.TensorPrimitives.LogOperator; +global using Log2Operator_Single = System.Numerics.Tensors.TensorPrimitives.Log2Operator; +global using MaxOperator_Single = System.Numerics.Tensors.TensorPrimitives.MaxOperator; +global using MaxPropagateNaNOperator_Single = System.Numerics.Tensors.TensorPrimitives.MaxPropagateNaNOperator; +global using MaxMagnitudeOperator_Single = System.Numerics.Tensors.TensorPrimitives.MaxMagnitudeOperator; +global using MaxMagnitudePropagateNaNOperator_Single = System.Numerics.Tensors.TensorPrimitives.MaxMagnitudePropagateNaNOperator; +global using MinOperator_Single = System.Numerics.Tensors.TensorPrimitives.MinOperator; +global using MinPropagateNaNOperator_Single = System.Numerics.Tensors.TensorPrimitives.MinPropagateNaNOperator; +global using MinMagnitudeOperator_Single = System.Numerics.Tensors.TensorPrimitives.MinMagnitudeOperator; +global using MinMagnitudePropagateNaNOperator_Single = System.Numerics.Tensors.TensorPrimitives.MinMagnitudePropagateNaNOperator; +global using MultiplyAddOperator_Single = System.Numerics.Tensors.TensorPrimitives.MultiplyAddOperator; +global using NegateOperator_Single = System.Numerics.Tensors.TensorPrimitives.NegateOperator; +global using IdentityOperator_Single = System.Numerics.Tensors.TensorPrimitives.IdentityOperator; +global using SubtractOperator_Single = System.Numerics.Tensors.TensorPrimitives.SubtractOperator; +global using SigmoidOperator_Single = System.Numerics.Tensors.TensorPrimitives.SigmoidOperator; +global using 
SinhOperator_Single = System.Numerics.Tensors.TensorPrimitives.SinhOperator; +global using SquaredOperator_Single = System.Numerics.Tensors.TensorPrimitives.SquaredOperator; +global using TanhOperator_Single = System.Numerics.Tensors.TensorPrimitives.TanhOperator; + +// TODO: These should be made generic. Their implementations are still currently bound to float. +global using IndexOfMaxOperator_Single = System.Numerics.Tensors.TensorPrimitives.IndexOfMaxOperator; +global using IndexOfMaxMagnitudeOperator_Single = System.Numerics.Tensors.TensorPrimitives.IndexOfMaxMagnitudeOperator; +global using IndexOfMinOperator_Single = System.Numerics.Tensors.TensorPrimitives.IndexOfMinOperator; +global using IndexOfMinMagnitudeOperator_Single = System.Numerics.Tensors.TensorPrimitives.IndexOfMinMagnitudeOperator; + +namespace System.Numerics.Tensors +{ + public static unsafe partial class TensorPrimitives + { + private static void InvokeSpanIntoSpan( + ReadOnlySpan x, Span destination) + where TSingleUnaryOperator : struct, IUnaryOperator => + InvokeSpanIntoSpan(x, destination); + + private static void InvokeSpanSpanIntoSpan( + ReadOnlySpan x, ReadOnlySpan y, Span destination) + where TSingleBinaryOperator : struct, IBinaryOperator => + InvokeSpanSpanIntoSpan(x, y, destination); + + private static void InvokeSpanScalarIntoSpan( + ReadOnlySpan x, float y, Span destination) + where TSingleBinaryOperator : struct, IBinaryOperator => + InvokeSpanScalarIntoSpan, TSingleBinaryOperator>(x, y, destination); + + private static unsafe void InvokeSpanScalarIntoSpan( + ReadOnlySpan x, float y, Span destination) + where TSingleTransformOperator : struct, IUnaryOperator + where TSingleBinaryOperator : struct, IBinaryOperator => + InvokeSpanScalarIntoSpan(x, y, destination); + + private static unsafe void InvokeSpanSpanSpanIntoSpan( + ReadOnlySpan x, ReadOnlySpan y, ReadOnlySpan z, Span destination) + where TSingleTernaryOperator : struct, ITernaryOperator => + InvokeSpanSpanSpanIntoSpan(x, y, z, destination); + + private static void InvokeSpanSpanScalarIntoSpan( + ReadOnlySpan x, ReadOnlySpan y, float z, Span destination) + where TSingleTernaryOperator : struct, ITernaryOperator => + InvokeSpanSpanScalarIntoSpan(x, y, z, destination); + + private static void InvokeSpanScalarSpanIntoSpan( + ReadOnlySpan x, float y, ReadOnlySpan z, Span destination) + where TSingleTernaryOperator : struct, ITernaryOperator => + InvokeSpanScalarSpanIntoSpan(x, y, z, destination); + + private static unsafe float Aggregate( + ReadOnlySpan x) + where TSingleTransformOperator : struct, IUnaryOperator + where TSingleAggregationOperator : struct, IAggregationOperator => + Aggregate(x); + + private static float Aggregate( + ReadOnlySpan x, ReadOnlySpan y) + where TSingleBinaryOperator : struct, IBinaryOperator + where TSingleAggregationOperator : struct, IAggregationOperator => + Aggregate(x, y); + + private static float MinMaxCore(ReadOnlySpan x) + where TSingleMinMaxOperator : struct, IAggregationOperator => + MinMaxCore(x); + } +} diff --git a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/TensorPrimitives.T.cs b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/TensorPrimitives.T.cs new file mode 100644 index 00000000000000..f9a44e8680123d --- /dev/null +++ b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/TensorPrimitives.T.cs @@ -0,0 +1,1036 @@ +// Licensed to the .NET Foundation under one or more agreements. 
+// The .NET Foundation licenses this file to you under the MIT license. + +// TODO: +// - Provide generic overloads for the IndexOfMin/Max{Magnitude} methods +namespace System.Numerics.Tensors +{ + /// Performs primitive tensor operations over spans of memory. + public static partial class TensorPrimitives + { + /// Computes the element-wise absolute value of each number in the specified tensor. + /// The tensor, represented as a span. + /// The destination tensor, represented as a span. + /// Destination is too short. + /// and reference overlapping memory locations and do not begin at the same location. + /// is a signed integer type and contained a value equal to 's minimum value. + /// + /// + /// This method effectively computes [i] = MathF.Abs([i]). + /// + /// + /// The absolute value of a number is its numeric value without its sign. For example, the absolute value of both 1.2e03 and -1.2e03 is 1.2e03. + /// + /// + /// If a value is equal to or , the result stored into the corresponding destination location is set to . + /// If a value is equal to , the result stored into the corresponding destination location is the original NaN value with the sign bit removed. + /// + /// + public static void Abs(ReadOnlySpan x, Span destination) + where T : INumberBase => + InvokeSpanIntoSpan>(x, destination); + + /// Computes the element-wise addition of numbers in the specified tensors. + /// The first tensor, represented as a span. + /// The second tensor, represented as a span. + /// The destination tensor, represented as a span. + /// Length of must be same as length of . + /// Destination is too short. + /// and reference overlapping memory locations and do not begin at the same location. + /// and reference overlapping memory locations and do not begin at the same location. + /// + /// + /// This method effectively computes [i] = [i] + [i]. + /// + /// + /// If either of the element-wise input values is equal to , the resulting element-wise value is also NaN. + /// + /// + public static void Add(ReadOnlySpan x, ReadOnlySpan y, Span destination) + where T : IAdditionOperators, IAdditiveIdentity => + InvokeSpanSpanIntoSpan>(x, y, destination); + + /// Computes the element-wise addition of numbers in the specified tensors. + /// The first tensor, represented as a span. + /// The second tensor, represented as a scalar. + /// The destination tensor, represented as a span. + /// Destination is too short. + /// and reference overlapping memory locations and do not begin at the same location. + /// + /// + /// This method effectively computes [i] = [i] + . + /// + /// + /// If either of the element-wise input values is equal to , the resulting element-wise value is also NaN. + /// + /// + public static void Add(ReadOnlySpan x, T y, Span destination) + where T : IAdditionOperators, IAdditiveIdentity => + InvokeSpanScalarIntoSpan>(x, y, destination); + + /// Computes the element-wise result of ( + ) * for the specified tensors. + /// The first tensor, represented as a span. + /// The second tensor, represented as a span. + /// The third tensor, represented as a span. + /// The destination tensor, represented as a span. + /// Length of must be same as length of and the length of . + /// Destination is too short. + /// and reference overlapping memory locations and do not begin at the same location. + /// and reference overlapping memory locations and do not begin at the same location. + /// and reference overlapping memory locations and do not begin at the same location.
+ /// + /// + /// This method effectively computes [i] = ([i] + [i]) * [i]. + /// + /// + /// If any of the element-wise input values is equal to , the resulting element-wise value is also NaN. + /// + /// + public static void AddMultiply(ReadOnlySpan x, ReadOnlySpan y, ReadOnlySpan multiplier, Span destination) + where T : IAdditionOperators, IMultiplyOperators => + InvokeSpanSpanSpanIntoSpan>(x, y, multiplier, destination); + + /// Computes the element-wise result of ( + ) * for the specified tensors. + /// The first tensor, represented as a span. + /// The second tensor, represented as a span. + /// The third tensor, represented as a scalar. + /// The destination tensor, represented as a span. + /// Length of must be same as length of . + /// Destination is too short. + /// and reference overlapping memory locations and do not begin at the same location. + /// and reference overlapping memory locations and do not begin at the same location. + /// + /// + /// This method effectively computes [i] = ([i] + [i]) * . + /// + /// + /// If any of the element-wise input values is equal to , the resulting element-wise value is also NaN. + /// + /// + public static void AddMultiply(ReadOnlySpan x, ReadOnlySpan y, T multiplier, Span destination) + where T : IAdditionOperators, IMultiplyOperators => + InvokeSpanSpanScalarIntoSpan>(x, y, multiplier, destination); + + /// Computes the element-wise result of ( + ) * for the specified tensors. + /// The first tensor, represented as a span. + /// The second tensor, represented as a scalar. + /// The third tensor, represented as a span. + /// The destination tensor, represented as a span. + /// Length of must be same as length of . + /// Destination is too short. + /// and reference overlapping memory locations and do not begin at the same location. + /// and reference overlapping memory locations and do not begin at the same location. + /// + /// + /// This method effectively computes [i] = ([i] + ) * [i]. + /// + /// + /// If any of the element-wise input values is equal to , the resulting element-wise value is also NaN. + /// + /// + public static void AddMultiply(ReadOnlySpan x, T y, ReadOnlySpan multiplier, Span destination) + where T : IAdditionOperators, IMultiplyOperators => + InvokeSpanScalarSpanIntoSpan>(x, y, multiplier, destination); + + /// Computes the element-wise hyperbolic cosine of each radian angle in the specified tensor. + /// The tensor, represented as a span. + /// The destination tensor, represented as a span. + /// Destination is too short. + /// and reference overlapping memory locations and do not begin at the same location. + /// + /// + /// This method effectively computes [i] = .Cosh([i]). + /// + /// + /// If a value is equal to or , the result stored into the corresponding destination location is set to . + /// If a value is equal to , the result stored into the corresponding destination location is also NaN. + /// + /// + /// The angles in x must be in radians. Use or multiply by /180 to convert degrees to radians. + /// + /// + /// This method may call into the underlying C runtime or employ instructions specific to the current architecture. Exact results may differ between different + /// operating systems or architectures. + /// + /// + public static void Cosh(ReadOnlySpan x, Span destination) + where T : IHyperbolicFunctions => + InvokeSpanIntoSpan>(x, destination); + + /// Computes the cosine similarity between the two specified non-empty, equal-length tensors of numbers. 
+ /// The first tensor, represented as a span. + /// The second tensor, represented as a span. + /// The cosine similarity of the two tensors. + /// Length of must be same as length of . + /// and must not be empty. + /// + /// + /// This method effectively computes TensorPrimitives.Dot(x, y) / (MathF.Sqrt(TensorPrimitives.SumOfSquares(x)) * MathF.Sqrt(TensorPrimitives.SumOfSquares(y))). + /// + /// + /// If any element in either input tensor is equal to , , or , + /// NaN is returned. + /// + /// + /// This method may call into the underlying C runtime or employ instructions specific to the current architecture. Exact results may differ between different + /// operating systems or architectures. + /// + /// + public static T CosineSimilarity(ReadOnlySpan x, ReadOnlySpan y) + where T : IRootFunctions => + CosineSimilarityCore(x, y); + + /// Computes the distance between two points, specified as non-empty, equal-length tensors of numbers, in Euclidean space. + /// The first tensor, represented as a span. + /// The second tensor, represented as a span. + /// The Euclidean distance. + /// Length of must be same as length of . + /// and must not be empty. + /// + /// + /// This method effectively computes the equivalent of: + /// + /// Span&lt;T&gt; difference = ...; + /// TensorPrimitives.Subtract(x, y, difference); + /// T result = MathF.Sqrt(TensorPrimitives.SumOfSquares(difference)); + /// + /// but without requiring additional temporary storage for the intermediate differences. + /// + /// + /// If any element in either input tensor is equal to , NaN is returned. + /// + /// + /// This method may call into the underlying C runtime or employ instructions specific to the current architecture. Exact results may differ between different + /// operating systems or architectures. + /// + /// + public static T Distance(ReadOnlySpan x, ReadOnlySpan y) + where T : IRootFunctions + { + if (x.IsEmpty) + { + ThrowHelper.ThrowArgument_SpansMustBeNonEmpty(); + } + + return T.Sqrt(Aggregate, AddOperator>(x, y)); + } + + /// Computes the element-wise division of numbers in the specified tensors. + /// The first tensor, represented as a span. + /// The second tensor, represented as a span. + /// The destination tensor, represented as a span. + /// Length of must be same as length of . + /// Destination is too short. + /// and reference overlapping memory locations and do not begin at the same location. + /// and reference overlapping memory locations and do not begin at the same location. + /// is an integer type and an element in is equal to zero. + /// + /// + /// This method effectively computes [i] = [i] / [i]. + /// + /// + /// If either of the element-wise input values is equal to , the resulting element-wise value is also NaN. + /// + /// + public static void Divide(ReadOnlySpan x, ReadOnlySpan y, Span destination) + where T : IDivisionOperators => + InvokeSpanSpanIntoSpan>(x, y, destination); + + /// Computes the element-wise division of numbers in the specified tensors. + /// The first tensor, represented as a span. + /// The second tensor, represented as a scalar. + /// The destination tensor, represented as a span. + /// Destination is too short. + /// and reference overlapping memory locations and do not begin at the same location. + /// is an integer type and is equal to zero. + /// + /// + /// This method effectively computes [i] = [i] / . + /// + /// + /// If either of the element-wise input values is equal to , the resulting element-wise value is also NaN.
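// Illustrative usage only, not part of this change: the equivalences described above,
// instantiated for float; the values are assumptions for the example.
ReadOnlySpan<float> u = new float[] { 1f, 2f, 3f };
ReadOnlySpan<float> v = new float[] { 4f, 5f, 6f };
float dot = TensorPrimitives.Dot(u, v);                     // 1*4 + 2*5 + 3*6 = 32
float similarity = TensorPrimitives.CosineSimilarity(u, v); // 32 / (Sqrt(14) * Sqrt(77))
float distance = TensorPrimitives.Distance(u, v);           // Sqrt(9 + 9 + 9), about 5.196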
+ /// + /// + public static void Divide(ReadOnlySpan x, T y, Span destination) + where T : IDivisionOperators => + InvokeSpanScalarIntoSpan>(x, y, destination); + + /// Computes the dot product of two tensors containing numbers. + /// The first tensor, represented as a span. + /// The second tensor, represented as a span. + /// The dot product. + /// Length of must be same as length of . + /// + /// + /// This method effectively computes the equivalent of: + /// + /// Span<T> products = ...; + /// TensorPrimitives.Multiply(x, y, products); + /// T result = TensorPrimitives.Sum(products); + /// + /// but without requiring additional temporary storage for the intermediate products. It corresponds to the dot method defined by BLAS1. + /// + /// + /// If any of the input elements is equal to , the resulting value is also NaN. + /// + /// + /// This method may call into the underlying C runtime or employ instructions specific to the current architecture. Exact results may differ between different + /// operating systems or architectures. + /// + /// + public static T Dot(ReadOnlySpan x, ReadOnlySpan y) + where T : IAdditionOperators, IAdditiveIdentity, IMultiplyOperators, IMultiplicativeIdentity => + Aggregate, AddOperator>(x, y); + + /// Computes the element-wise result of raising e to the number powers in the specified tensor. + /// The tensor, represented as a span. + /// The destination tensor, represented as a span. + /// Destination is too short. + /// and reference overlapping memory locations and do not begin at the same location. + /// + /// + /// This method effectively computes [i] = .Exp([i]). + /// + /// + /// If a value equals or , the result stored into the corresponding destination location is set to NaN. + /// If a value equals , the result stored into the corresponding destination location is set to 0. + /// + /// + /// This method may call into the underlying C runtime or employ instructions specific to the current architecture. Exact results may differ between different + /// operating systems or architectures. + /// + /// + public static void Exp(ReadOnlySpan x, Span destination) + where T : IExponentialFunctions => + InvokeSpanIntoSpan>(x, destination); + + // TODO: Make IndexOfXx methods generic + // + ///// Searches for the index of the largest number in the specified tensor. + ///// The tensor, represented as a span. + ///// The index of the maximum element in , or -1 if is empty. + ///// + ///// + ///// The determination of the maximum element matches the IEEE 754:2019 `maximum` function. If any value equal to + ///// is present, the index of the first is returned. Positive 0 is considered greater than negative 0. + ///// + ///// + ///// This method may call into the underlying C runtime or employ instructions specific to the current architecture. Exact results may differ between different + ///// operating systems or architectures. + ///// + ///// + //public static int IndexOfMax(ReadOnlySpan x) + // where T : INumber => + // IndexOfMinMaxCore>(x); + // + ///// Searches for the index of the number with the largest magnitude in the specified tensor. + ///// The tensor, represented as a span. + ///// The index of the element in with the largest magnitude (absolute value), or -1 if is empty. + ///// + ///// + ///// The determination of the maximum magnitude matches the IEEE 754:2019 `maximumMagnitude` function. If any value equal to + ///// is present, the index of the first is returned. 
If two values have the same magnitude and one is positive and the other is negative, + ///// the positive value is considered to have the larger magnitude. + ///// + ///// + ///// This method may call into the underlying C runtime or employ instructions specific to the current architecture. Exact results may differ between different + ///// operating systems or architectures. + ///// + ///// + //public static int IndexOfMaxMagnitude(ReadOnlySpan x) + // where T : INumberBase => + // IndexOfMinMaxCore>(x); + // + ///// Searches for the index of the smallest number in the specified tensor. + ///// The tensor, represented as a span. + ///// The index of the minimum element in , or -1 if is empty. + ///// + ///// + ///// The determination of the minimum element matches the IEEE 754:2019 `minimum` function. If any value equal to + ///// is present, the index of the first is returned. Negative 0 is considered smaller than positive 0. + ///// + ///// + ///// This method may call into the underlying C runtime or employ instructions specific to the current architecture. Exact results may differ between different + ///// operating systems or architectures. + ///// + ///// + //public static int IndexOfMin(ReadOnlySpan x) + // where T : INumber => + // IndexOfMinMaxCore>(x); + // + ///// Searches for the index of the number with the smallest magnitude in the specified tensor. + ///// The tensor, represented as a span. + ///// The index of the element in with the smallest magnitude (absolute value), or -1 if is empty. + ///// + ///// + ///// The determination of the minimum magnitude matches the IEEE 754:2019 `minimumMagnitude` function. If any value equal to + ///// is present, the index of the first is returned. If two values have the same magnitude and one is positive and the other is negative, + ///// the negative value is considered to have the smaller magnitude. + ///// + ///// + ///// This method may call into the underlying C runtime or employ instructions specific to the current architecture. Exact results may differ between different + ///// operating systems or architectures. + ///// + ///// + //public static int IndexOfMinMagnitude(ReadOnlySpan x) + // where T : INumberBase => + // IndexOfMinMaxCore>(x); + + /// Computes the element-wise natural (base e) logarithm of numbers in the specified tensor. + /// The tensor, represented as a span. + /// The destination tensor, represented as a span. + /// Destination is too short. + /// and reference overlapping memory locations and do not begin at the same location. + /// + /// + /// This method effectively computes [i] = .Log([i]). + /// + /// + /// If a value equals 0, the result stored into the corresponding destination location is set to . + /// If a value is negative or equal to , the result stored into the corresponding destination location is set to NaN. + /// If a value is positive infinity, the result stored into the corresponding destination location is set to . + /// Otherwise, if a value is positive, its natural logarithm is stored into the corresponding destination location. + /// + /// + /// This method may call into the underlying C runtime or employ instructions specific to the current architecture. Exact results may differ between different + /// operating systems or architectures. + /// + /// + public static void Log(ReadOnlySpan x, Span destination) + where T : ILogarithmicFunctions => + InvokeSpanIntoSpan>(x, destination); + + /// Computes the element-wise base 2 logarithm of numbers in the specified tensor. 
+ /// The tensor, represented as a span. + /// The destination tensor, represented as a span. + /// Destination is too short. + /// and reference overlapping memory locations and do not begin at the same location. + /// + /// + /// This method effectively computes [i] = .Log2([i]). + /// + /// + /// If a value equals 0, the result stored into the corresponding destination location is set to . + /// If a value is negative or equal to , the result stored into the corresponding destination location is set to NaN. + /// If a value is positive infinity, the result stored into the corresponding destination location is set to . + /// Otherwise, if a value is positive, its base 2 logarithm is stored into the corresponding destination location. + /// + /// + /// This method may call into the underlying C runtime or employ instructions specific to the current architecture. Exact results may differ between different + /// operating systems or architectures. + /// + /// + public static void Log2(ReadOnlySpan x, Span destination) + where T : ILogarithmicFunctions => + InvokeSpanIntoSpan>(x, destination); + + /// Searches for the largest number in the specified tensor. + /// The tensor, represented as a span. + /// The maximum element in . + /// Length of must be greater than zero. + /// + /// + /// The determination of the maximum element matches the IEEE 754:2019 `maximum` function. If any value equal to + /// is present, the first is returned. Positive 0 is considered greater than negative 0. + /// + /// + /// This method may call into the underlying C runtime or employ instructions specific to the current architecture. Exact results may differ between different + /// operating systems or architectures. + /// + /// + public static T Max(ReadOnlySpan x) + where T : INumber => + MinMaxCore>(x); + + /// Computes the element-wise maximum of the numbers in the specified tensors. + /// The first tensor, represented as a span. + /// The second tensor, represented as a span. + /// The destination tensor, represented as a span. + /// Length of must be same as length of . + /// Destination is too short. + /// and reference overlapping memory locations and do not begin at the same location. + /// and reference overlapping memory locations and do not begin at the same location. + /// + /// + /// This method effectively computes [i] = MathF.Max([i], [i]). + /// + /// + /// The determination of the maximum element matches the IEEE 754:2019 `maximum` function. If either value is equal to , + /// that value is stored as the result. Positive 0 is considered greater than negative 0. + /// + /// + /// This method may call into the underlying C runtime or employ instructions specific to the current architecture. Exact results may differ between different + /// operating systems or architectures. + /// + /// + public static void Max(ReadOnlySpan x, ReadOnlySpan y, Span destination) + where T : INumber => + InvokeSpanSpanIntoSpan>(x, y, destination); + + /// Searches for the number with the largest magnitude in the specified tensor. + /// The tensor, represented as a span. + /// The element in with the largest magnitude (absolute value). + /// Length of must be greater than zero. + /// + /// + /// The determination of the maximum magnitude matches the IEEE 754:2019 `maximumMagnitude` function. If any value equal to + /// is present, the first is returned. If two values have the same magnitude and one is positive and the other is negative, + /// the positive value is considered to have the larger magnitude.
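// Illustrative usage only, not part of this change: NaN propagation and magnitude tie-breaking
// for the generic Max/MaxMagnitude surface, instantiated for float.
float max = TensorPrimitives.Max<float>(new float[] { -5f, 2f });             // 2
float maxNaN = TensorPrimitives.Max<float>(new float[] { -5f, float.NaN });   // NaN propagates
float maxMag = TensorPrimitives.MaxMagnitude<float>(new float[] { -5f, 2f }); // -5, since |-5| > |2|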
+ /// + /// + /// This method may call into the underlying C runtime or employ instructions specific to the current architecture. Exact results may differ between different + /// operating systems or architectures. + /// + /// + public static T MaxMagnitude(ReadOnlySpan x) + where T : INumberBase => + MinMaxCore>(x); + + /// Computes the element-wise number with the largest magnitude in the specified tensors. + /// The first tensor, represented as a span. + /// The second tensor, represented as a span. + /// The destination tensor, represented as a span. + /// Length of must be same as length of . + /// Destination is too short. + /// and reference overlapping memory locations and do not begin at the same location. + /// and reference overlapping memory locations and do not begin at the same location. + /// This method effectively computes [i] = MathF.MaxMagnitude([i], [i]). + /// + /// + /// This method may call into the underlying C runtime or employ instructions specific to the current architecture. Exact results may differ between different + /// operating systems or architectures. + /// + /// + public static void MaxMagnitude(ReadOnlySpan x, ReadOnlySpan y, Span destination) + where T : INumberBase => + InvokeSpanSpanIntoSpan>(x, y, destination); + + /// Searches for the smallest number in the specified tensor. + /// The tensor, represented as a span. + /// The minimum element in . + /// Length of must be greater than zero. + /// + /// + /// The determination of the minimum element matches the IEEE 754:2019 `minimum` function. If any value equal to + /// is present, the first is returned. Negative 0 is considered smaller than positive 0. + /// + /// + /// This method may call into the underlying C runtime or employ instructions specific to the current architecture. Exact results may differ between different + /// operating systems or architectures. + /// + /// + public static T Min(ReadOnlySpan x) + where T : INumber => + MinMaxCore>(x); + + /// Computes the element-wise minimum of the numbers in the specified tensors. + /// The first tensor, represented as a span. + /// The second tensor, represented as a span. + /// The destination tensor, represented as a span. + /// Length of must be same as length of . + /// Destination is too short. + /// and reference overlapping memory locations and do not begin at the same location. + /// and reference overlapping memory locations and do not begin at the same location. + /// + /// + /// This method effectively computes [i] = MathF.Min([i], [i]). + /// + /// + /// The determination of the minimum element matches the IEEE 754:2019 `minimum` function. If either value is equal to , + /// that value is stored as the result. Negative 0 is considered smaller than positive 0. + /// + /// + /// This method may call into the underlying C runtime or employ instructions specific to the current architecture. Exact results may differ between different + /// operating systems or architectures. + /// + /// + public static void Min(ReadOnlySpan x, ReadOnlySpan y, Span destination) + where T : INumber => + InvokeSpanSpanIntoSpan>(x, y, destination); + + /// Searches for the number with the smallest magnitude in the specified tensor. + /// The tensor, represented as a span. + /// The element in with the smallest magnitude (absolute value). + /// Length of must be greater than zero. + /// + /// + /// The determination of the minimum magnitude matches the IEEE 754:2019 `minimumMagnitude` function. If any value equal to + /// is present, the first is returned.
+ /// If two values have the same magnitude and one is positive and the other is negative, + /// the negative value is considered to have the smaller magnitude. + /// + /// + /// This method may call into the underlying C runtime or employ instructions specific to the current architecture. Exact results may differ between different + /// operating systems or architectures. + /// + /// + public static T MinMagnitude(ReadOnlySpan x) + where T : INumberBase => + MinMaxCore>(x); + + /// Computes the element-wise number with the smallest magnitude in the specified tensors. + /// The first tensor, represented as a span. + /// The second tensor, represented as a span. + /// The destination tensor, represented as a span. + /// Length of must be same as length of . + /// Destination is too short. + /// and reference overlapping memory locations and do not begin at the same location. + /// and reference overlapping memory locations and do not begin at the same location. + /// This method effectively computes [i] = MathF.MinMagnitude([i], [i]). + /// + /// + /// The determination of the minimum magnitude matches the IEEE 754:2019 `minimumMagnitude` function. If either value is equal to , + /// that value is stored as the result. If the two values have the same magnitude and one is positive and the other is negative, + /// the negative value is considered to have the smaller magnitude. + /// + /// + /// This method may call into the underlying C runtime or employ instructions specific to the current architecture. Exact results may differ between different + /// operating systems or architectures. + /// + /// + public static void MinMagnitude(ReadOnlySpan x, ReadOnlySpan y, Span destination) + where T : INumberBase => + InvokeSpanSpanIntoSpan>(x, y, destination); + + /// Computes the element-wise product of numbers in the specified tensors. + /// The first tensor, represented as a span. + /// The second tensor, represented as a span. + /// The destination tensor, represented as a span. + /// Length of must be same as length of . + /// Destination is too short. + /// and reference overlapping memory locations and do not begin at the same location. + /// and reference overlapping memory locations and do not begin at the same location. + /// + /// + /// This method effectively computes [i] = [i] * [i]. + /// + /// + /// If either of the element-wise input values is equal to , the resulting element-wise value is also NaN. + /// + /// + public static void Multiply(ReadOnlySpan x, ReadOnlySpan y, Span destination) + where T : IMultiplyOperators, IMultiplicativeIdentity => + InvokeSpanSpanIntoSpan>(x, y, destination); + + /// Computes the element-wise product of numbers in the specified tensors. + /// The first tensor, represented as a span. + /// The second tensor, represented as a scalar. + /// The destination tensor, represented as a span. + /// Destination is too short. + /// and reference overlapping memory locations and do not begin at the same location. + /// + /// + /// This method effectively computes [i] = [i] * . + /// It corresponds to the scal method defined by BLAS1. + /// + /// + /// If either of the element-wise input values is equal to , the resulting element-wise value is also NaN. + /// + /// + public static void Multiply(ReadOnlySpan x, T y, Span destination) + where T : IMultiplyOperators, IMultiplicativeIdentity => + InvokeSpanScalarIntoSpan>(x, y, destination); + + /// Computes the element-wise result of ( * ) + for the specified tensors of numbers.
+ /// The first tensor, represented as a span. + /// The second tensor, represented as a span. + /// The third tensor, represented as a span. + /// The destination tensor, represented as a span. + /// Length of must be same as length of and length of . + /// Destination is too short. + /// and reference overlapping memory locations and do not begin at the same location. + /// and reference overlapping memory locations and do not begin at the same location. + /// and reference overlapping memory locations and do not begin at the same location. + /// + /// + /// This method effectively computes [i] = ([i] * [i]) + [i]. + /// + /// + /// If any of the element-wise input values is equal to , the resulting element-wise value is also NaN. + /// + /// + public static void MultiplyAdd(ReadOnlySpan x, ReadOnlySpan y, ReadOnlySpan addend, Span destination) + where T : IAdditionOperators, IMultiplyOperators => + InvokeSpanSpanSpanIntoSpan>(x, y, addend, destination); + + /// Computes the element-wise result of ( * ) + for the specified tensors of numbers. + /// The first tensor, represented as a span. + /// The second tensor, represented as a span. + /// The third tensor, represented as a scalar. + /// The destination tensor, represented as a span. + /// Length of must be same as length of . + /// Destination is too short. + /// and reference overlapping memory locations and do not begin at the same location. + /// and reference overlapping memory locations and do not begin at the same location. + /// + /// + /// This method effectively computes [i] = ([i] * [i]) + . + /// + /// + /// If any of the element-wise input values is equal to , the resulting element-wise value is also NaN. + /// + /// + public static void MultiplyAdd(ReadOnlySpan x, ReadOnlySpan y, T addend, Span destination) + where T : IAdditionOperators, IMultiplyOperators => + InvokeSpanSpanScalarIntoSpan>(x, y, addend, destination); + + /// Computes the element-wise result of ( * ) + for the specified tensors of numbers. + /// The first tensor, represented as a span. + /// The second tensor, represented as a scalar. + /// The third tensor, represented as a span. + /// The destination tensor, represented as a span. + /// Length of must be same as length of . + /// Destination is too short. + /// and reference overlapping memory locations and do not begin at the same location. + /// and reference overlapping memory locations and do not begin at the same location. + /// + /// + /// This method effectively computes [i] = ([i] * ) + [i]. + /// It corresponds to the axpy method defined by BLAS1. + /// + /// + /// If any of the element-wise input values is equal to , the resulting element-wise value is also NaN. + /// + /// + public static void MultiplyAdd(ReadOnlySpan x, T y, ReadOnlySpan addend, Span destination) + where T : IAdditionOperators, IMultiplyOperators => + InvokeSpanScalarSpanIntoSpan>(x, y, addend, destination); + + /// Computes the element-wise negation of each number in the specified tensor. + /// The tensor, represented as a span. + /// The destination tensor, represented as a span. + /// Destination is too short. + /// and reference overlapping memory locations and do not begin at the same location. + /// + /// + /// This method effectively computes [i] = -[i]. + /// + /// + /// If any of the element-wise input values is equal to , the resulting element-wise value is also NaN.
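// Illustrative usage only, not part of this change: the span-scalar-span overload expresses the
// BLAS1 axpy pattern (y := a*x + y). Reusing ys as both addend and destination is permitted
// because the overlapping spans begin at the same location.
float a = 2f;
float[] xs = { 1f, 2f, 3f };
float[] ys = { 10f, 20f, 30f };
TensorPrimitives.MultiplyAdd(xs, a, ys, ys); // ys is now { 12f, 24f, 36f }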
+ /// + /// + public static void Negate(ReadOnlySpan x, Span destination) + where T : IUnaryNegationOperators => + InvokeSpanIntoSpan>(x, destination); + + /// Computes the Euclidean norm of the specified tensor of numbers. + /// The first tensor, represented as a span. + /// The norm. + /// + /// + /// This method effectively computes MathF.Sqrt(TensorPrimitives.SumOfSquares(x)). + /// This is often referred to as the Euclidean norm or L2 norm. + /// It corresponds to the nrm2 method defined by BLAS1. + /// + /// + /// If any of the input values is equal to , the result value is also NaN. + /// + /// + /// This method may call into the underlying C runtime or employ instructions specific to the current architecture. Exact results may differ between different + /// operating systems or architectures. + /// + /// + public static T Norm(ReadOnlySpan x) + where T : IRootFunctions => + T.Sqrt(SumOfSquares(x)); + + /// Computes the product of all elements in the specified non-empty tensor of numbers. + /// The tensor, represented as a span. + /// The result of multiplying all elements in . + /// Length of must be greater than zero. + /// + /// + /// If any of the input values is equal to , the result value is also NaN. + /// + /// + /// This method may call into the underlying C runtime or employ instructions specific to the current architecture. Exact results may differ between different + /// operating systems or architectures. + /// + /// + public static T Product(ReadOnlySpan x) + where T : IMultiplyOperators, IMultiplicativeIdentity + { + if (x.IsEmpty) + { + ThrowHelper.ThrowArgument_SpansMustBeNonEmpty(); + } + + return Aggregate, MultiplyOperator>(x); + } + + /// Computes the product of the element-wise differences of the numbers in the specified non-empty tensors. + /// The first tensor, represented as a span. + /// The second tensor, represented as a span. + /// The result of multiplying the element-wise subtraction of the elements in the second tensor from the first tensor. + /// Length of both input spans must be greater than zero. + /// and must have the same length. + /// + /// + /// This method effectively computes: + /// + /// Span<T> differences = ...; + /// TensorPrimitives.Subtract(x, y, differences); + /// T result = TensorPrimitives.Product(differences); + /// + /// but without requiring additional temporary storage for the intermediate differences. + /// + /// + /// This method may call into the underlying C runtime or employ instructions specific to the current architecture. Exact results may differ between different + /// operating systems or architectures. + /// + /// + public static T ProductOfDifferences(ReadOnlySpan x, ReadOnlySpan y) + where T : ISubtractionOperators, IMultiplyOperators, IMultiplicativeIdentity + { + if (x.IsEmpty) + { + ThrowHelper.ThrowArgument_SpansMustBeNonEmpty(); + } + + return Aggregate, MultiplyOperator>(x, y); + } + + /// Computes the product of the element-wise sums of the numbers in the specified non-empty tensors. + /// The first tensor, represented as a span. + /// The second tensor, represented as a span. + /// The result of multiplying the element-wise additions of the elements in each tensor. + /// Length of both input spans must be greater than zero. + /// and must have the same length. 
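// Illustrative usage only, not part of this change: Norm is Sqrt(SumOfSquares), and the
// ProductOf* helpers fuse a binary operation with Product to avoid a temporary buffer.
float[] p = { 3f, 4f };
float norm = TensorPrimitives.Norm<float>(p);                 // 5 = Sqrt(3*3 + 4*4)
float prodSums = TensorPrimitives.ProductOfSums<float>(p, p); // (3+3) * (4+4) = 48
float prodDiffs = TensorPrimitives.ProductOfDifferences<float>(p, new float[] { 1f, 1f }); // (3-1) * (4-1) = 6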
+ /// + /// + /// This method effectively computes: + /// + /// Span<T> sums = ...; + /// TensorPrimitives.Add(x, y, sums); + /// T result = TensorPrimitives.Product(sums); + /// + /// but without requiring additional temporary storage for the intermediate sums. + /// + /// + /// This method may call into the underlying C runtime or employ instructions specific to the current architecture. Exact results may differ between different + /// operating systems or architectures. + /// + /// + public static T ProductOfSums(ReadOnlySpan x, ReadOnlySpan y) + where T : IAdditionOperators, IAdditiveIdentity, IMultiplyOperators, IMultiplicativeIdentity + { + if (x.IsEmpty) + { + ThrowHelper.ThrowArgument_SpansMustBeNonEmpty(); + } + + return Aggregate, MultiplyOperator>(x, y); + } + + /// Computes the element-wise sigmoid function on the specified non-empty tensor of numbers. + /// The tensor, represented as a span. + /// The destination tensor. + /// Destination is too short. + /// must not be empty. + /// and reference overlapping memory locations and do not begin at the same location. + /// + /// + /// This method effectively computes [i] = 1f / (1f + .Exp(-[i])). + /// + /// + /// This method may call into the underlying C runtime or employ instructions specific to the current architecture. Exact results may differ between different + /// operating systems or architectures. + /// + /// + public static void Sigmoid(ReadOnlySpan x, Span destination) + where T : IExponentialFunctions + { + if (x.IsEmpty) + { + ThrowHelper.ThrowArgument_SpansMustBeNonEmpty(); + } + + InvokeSpanIntoSpan>(x, destination); + } + + /// Computes the element-wise hyperbolic sine of each radian angle in the specified tensor. + /// The tensor, represented as a span. + /// The destination tensor, represented as a span. + /// Destination is too short. + /// and reference overlapping memory locations and do not begin at the same location. + /// + /// + /// This method effectively computes [i] = .Sinh([i]). + /// + /// + /// If a value is equal to , , or , + /// the corresponding destination location is set to that value. + /// + /// + /// The angles in x must be in radians. Use or multiply by /180 to convert degrees to radians. + /// + /// + /// This method may call into the underlying C runtime or employ instructions specific to the current architecture. Exact results may differ between different + /// operating systems or architectures. + /// + /// + public static void Sinh(ReadOnlySpan x, Span destination) + where T : IHyperbolicFunctions => + InvokeSpanIntoSpan>(x, destination); + + /// Computes the softmax function over the specified non-empty tensor of numbers. + /// The tensor, represented as a span. + /// The destination tensor. + /// Destination is too short. + /// must not be empty. + /// and reference overlapping memory locations and do not begin at the same location. + /// + /// + /// This method effectively computes a sum of MathF.Exp(x[i]) for all elements in . + /// It then effectively computes [i] = MathF.Exp([i]) / sum. + /// + /// + /// This method may call into the underlying C runtime or employ instructions specific to the current architecture. Exact results may differ between different + /// operating systems or architectures. 
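// Illustrative usage only, not part of this change: Sigmoid and SoftMax instantiated for float.
float[] logits = { 1f, 2f, 3f };
float[] result = new float[3];
TensorPrimitives.SoftMax<float>(logits, result); // result sums to ~1; result[2] is the largest
TensorPrimitives.Sigmoid<float>(logits, result); // result[i] = 1 / (1 + Exp(-logits[i]))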
+ /// + /// + public static void SoftMax(ReadOnlySpan x, Span destination) + where T : IExponentialFunctions + { + if (x.IsEmpty) + { + ThrowHelper.ThrowArgument_SpansMustBeNonEmpty(); + } + + if (x.Length > destination.Length) + { + ThrowHelper.ThrowArgument_DestinationTooShort(); + } + + ValidateInputOutputSpanNonOverlapping(x, destination); + + T expSum = Aggregate, AddOperator>(x); + + InvokeSpanScalarIntoSpan, DivideOperator>(x, expSum, destination); + } + + /// Computes the element-wise difference between numbers in the specified tensors. + /// The first tensor, represented as a span. + /// The second tensor, represented as a span. + /// The destination tensor, represented as a span. + /// Length of must be same as length of . + /// Destination is too short. + /// and reference overlapping memory locations and do not begin at the same location. + /// and reference overlapping memory locations and do not begin at the same location. + /// + /// + /// This method effectively computes [i] = [i] - [i]. + /// + /// + /// If either of the element-wise input values is equal to , the resulting element-wise value is also NaN. + /// + /// + public static void Subtract(ReadOnlySpan x, ReadOnlySpan y, Span destination) + where T : ISubtractionOperators => + InvokeSpanSpanIntoSpan>(x, y, destination); + + /// Computes the element-wise difference between numbers in the specified tensors. + /// The first tensor, represented as a span. + /// The second tensor, represented as a scalar. + /// The destination tensor, represented as a span. + /// Destination is too short. + /// and reference overlapping memory locations and do not begin at the same location. + /// + /// + /// This method effectively computes [i] = [i] - . + /// + /// + /// If either of the element-wise input values is equal to , the resulting element-wise value is also NaN. + /// + /// + public static void Subtract(ReadOnlySpan x, T y, Span destination) + where T : ISubtractionOperators => + InvokeSpanScalarIntoSpan>(x, y, destination); + + /// Computes the sum of all elements in the specified tensor of numbers. + /// The tensor, represented as a span. + /// The result of adding all elements in , or zero if is empty. + /// + /// + /// If any of the values in the input is equal to , the result is also NaN. + /// + /// + /// This method may call into the underlying C runtime or employ instructions specific to the current architecture. Exact results may differ between different + /// operating systems or architectures. + /// + /// + public static T Sum(ReadOnlySpan x) + where T : IAdditionOperators, IAdditiveIdentity => + Aggregate, AddOperator>(x); + + /// Computes the sum of the absolute values of every element in the specified tensor of numbers. + /// The tensor, represented as a span. + /// The result of adding the absolute value of every element in , or zero if is empty. + /// is a signed integer type and contained a value equal to 's minimum value. + /// + /// + /// This method effectively computes: + /// + /// Span&lt;T&gt; absoluteValues = ...; + /// TensorPrimitives.Abs(x, absoluteValues); + /// T result = TensorPrimitives.Sum(absoluteValues); + /// + /// but without requiring intermediate storage for the absolute values. It corresponds to the asum method defined by BLAS1. + /// + /// + /// This method may call into the underlying C runtime or employ instructions specific to the current architecture. Exact results may differ between different + /// operating systems or architectures.
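// Illustrative usage only, not part of this change: Sum versus SumOfMagnitudes (the BLAS1 asum).
float[] w = { 1f, -2f, 3f };
float sum = TensorPrimitives.Sum<float>(w);              // 2
float asum = TensorPrimitives.SumOfMagnitudes<float>(w); // 6 = |1| + |-2| + |3|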
+ /// + /// + public static T SumOfMagnitudes(ReadOnlySpan x) + where T : INumberBase => + Aggregate, AddOperator>(x); + + /// Computes the sum of the square of every element in the specified tensor of numbers. + /// The tensor, represented as a span. + /// The result of adding the square of every element in , or zero if is empty. + /// + /// + /// This method effectively computes: + /// + /// Span<T> squaredValues = ...; + /// TensorPrimitives.Multiply(x, x, squaredValues); + /// T result = TensorPrimitives.Sum(squaredValues); + /// + /// but without requiring intermediate storage for the squared values. + /// + /// + /// This method may call into the underlying C runtime or employ instructions specific to the current architecture. Exact results may differ between different + /// operating systems or architectures. + /// + /// + public static T SumOfSquares(ReadOnlySpan x) + where T : IAdditionOperators, IAdditiveIdentity, IMultiplyOperators => + Aggregate, AddOperator>(x); + + /// Computes the element-wise hyperbolic tangent of each radian angle in the specified tensor. + /// The tensor, represented as a span. + /// The destination tensor, represented as a span. + /// Destination is too short. + /// and reference overlapping memory locations and do not begin at the same location. + /// + /// + /// This method effectively computes [i] = .Tanh([i]). + /// + /// + /// If a value is equal to , the corresponding destination location is set to -1. + /// If a value is equal to , the corresponding destination location is set to 1. + /// If a value is equal to , the corresponding destination location is set to NaN. + /// + /// + /// The angles in x must be in radians. Use or multiply by /180 to convert degrees to radians. + /// + /// + /// This method may call into the underlying C runtime or employ instructions specific to the current architecture. Exact results may differ between different + /// operating systems or architectures. + /// + /// + public static void Tanh(ReadOnlySpan x, Span destination) + where T : IHyperbolicFunctions => + InvokeSpanIntoSpan>(x, destination); + } +} diff --git a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netcore.cs b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/TensorPrimitives.netcore.cs similarity index 65% rename from src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netcore.cs rename to src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/TensorPrimitives.netcore.cs index 9b275e165158af..e8fec7a4bde585 100644 --- a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netcore.cs +++ b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/TensorPrimitives.netcore.cs @@ -8,6 +8,13 @@ using System.Runtime.Intrinsics.Arm; using System.Runtime.Intrinsics.X86; +#pragma warning disable CS8500 // This takes the address of, gets the size of, or declares a pointer to a managed type + +// TODO: +// - Vectorize the trig-related functions for Ts other than floats +// - Vectorize integer operations when sizeof(T) == 1 or 2 (currently only vectorized in most operations for sizeof(T) == 4 or 8). +// - Implement generic version of IndexOfMinMaxCore and corresponding IndexOf methods. 
+ namespace System.Numerics.Tensors { public static unsafe partial class TensorPrimitives @@ -615,7 +622,7 @@ static Vector512 HalfAsWidenedUInt32ToSingle_Vector512(Vector512 va /// Computes the cosine similarity between the two specified non-empty, equal-length tensors of single-precision floating-point numbers. /// Assumes arguments have already been validated to be non-empty and equal length. - private static float CosineSimilarityCore(ReadOnlySpan x, ReadOnlySpan y) + private static T CosineSimilarityCore(ReadOnlySpan x, ReadOnlySpan y) where T : IRootFunctions { if (x.IsEmpty) { @@ -631,38 +638,38 @@ private static float CosineSimilarityCore(ReadOnlySpan x, ReadOnlySpan= Vector512.Count) + if (Vector512.IsHardwareAccelerated && Vector512.IsSupported && x.Length >= Vector512.Count) { - ref float xRef = ref MemoryMarshal.GetReference(x); - ref float yRef = ref MemoryMarshal.GetReference(y); + ref T xRef = ref MemoryMarshal.GetReference(x); + ref T yRef = ref MemoryMarshal.GetReference(y); - Vector512 dotProductVector = Vector512.Zero; - Vector512 xSumOfSquaresVector = Vector512.Zero; - Vector512 ySumOfSquaresVector = Vector512.Zero; + Vector512 dotProductVector = Vector512.Zero; + Vector512 xSumOfSquaresVector = Vector512.Zero; + Vector512 ySumOfSquaresVector = Vector512.Zero; // Process vectors, summing their dot products and squares, as long as there's a vector's worth remaining. - int oneVectorFromEnd = x.Length - Vector512.Count; + int oneVectorFromEnd = x.Length - Vector512.Count; int i = 0; do { - Vector512 xVec = Vector512.LoadUnsafe(ref xRef, (uint)i); - Vector512 yVec = Vector512.LoadUnsafe(ref yRef, (uint)i); + Vector512 xVec = Vector512.LoadUnsafe(ref xRef, (uint)i); + Vector512 yVec = Vector512.LoadUnsafe(ref yRef, (uint)i); dotProductVector = FusedMultiplyAdd(xVec, yVec, dotProductVector); xSumOfSquaresVector = FusedMultiplyAdd(xVec, xVec, xSumOfSquaresVector); ySumOfSquaresVector = FusedMultiplyAdd(yVec, yVec, ySumOfSquaresVector); - i += Vector512.Count; + i += Vector512.Count; } while (i <= oneVectorFromEnd); // Process the last vector in the span, masking off elements already processed. if (i != x.Length) { - Vector512 xVec = Vector512.LoadUnsafe(ref xRef, (uint)(x.Length - Vector512.Count)); - Vector512 yVec = Vector512.LoadUnsafe(ref yRef, (uint)(x.Length - Vector512.Count)); + Vector512 xVec = Vector512.LoadUnsafe(ref xRef, (uint)(x.Length - Vector512.Count)); + Vector512 yVec = Vector512.LoadUnsafe(ref yRef, (uint)(x.Length - Vector512.Count)); - Vector512 remainderMask = CreateRemainderMaskSingleVector512(x.Length - i); + Vector512 remainderMask = CreateRemainderMaskVector512(x.Length - i); xVec &= remainderMask; yVec &= remainderMask; @@ -674,41 +681,41 @@ private static float CosineSimilarityCore(ReadOnlySpan x, ReadOnlySpan= Vector256.Count) + if (Vector256.IsHardwareAccelerated && Vector256.IsSupported && x.Length >= Vector256.Count) { - ref float xRef = ref MemoryMarshal.GetReference(x); - ref float yRef = ref MemoryMarshal.GetReference(y); + ref T xRef = ref MemoryMarshal.GetReference(x); + ref T yRef = ref MemoryMarshal.GetReference(y); - Vector256 dotProductVector = Vector256.Zero; - Vector256 xSumOfSquaresVector = Vector256.Zero; - Vector256 ySumOfSquaresVector = Vector256.Zero; + Vector256 dotProductVector = Vector256.Zero; + Vector256 xSumOfSquaresVector = Vector256.Zero; + Vector256 ySumOfSquaresVector = Vector256.Zero; // Process vectors, summing their dot products and squares, as long as there's a vector's worth remaining. 
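// For concreteness: each iteration below performs three fused multiply-adds per vector pair
// (dot += x*y, xss += x*x, yss += y*y), so a single pass over the data yields all three
// reductions needed for Dot(x, y) / (Sqrt(SumOfSquares(x)) * Sqrt(SumOfSquares(y))).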
- int oneVectorFromEnd = x.Length - Vector256.Count; + int oneVectorFromEnd = x.Length - Vector256.Count; int i = 0; do { - Vector256 xVec = Vector256.LoadUnsafe(ref xRef, (uint)i); - Vector256 yVec = Vector256.LoadUnsafe(ref yRef, (uint)i); + Vector256 xVec = Vector256.LoadUnsafe(ref xRef, (uint)i); + Vector256 yVec = Vector256.LoadUnsafe(ref yRef, (uint)i); dotProductVector = FusedMultiplyAdd(xVec, yVec, dotProductVector); xSumOfSquaresVector = FusedMultiplyAdd(xVec, xVec, xSumOfSquaresVector); ySumOfSquaresVector = FusedMultiplyAdd(yVec, yVec, ySumOfSquaresVector); - i += Vector256.Count; + i += Vector256.Count; } while (i <= oneVectorFromEnd); // Process the last vector in the span, masking off elements already processed. if (i != x.Length) { - Vector256 xVec = Vector256.LoadUnsafe(ref xRef, (uint)(x.Length - Vector256.Count)); - Vector256 yVec = Vector256.LoadUnsafe(ref yRef, (uint)(x.Length - Vector256.Count)); + Vector256 xVec = Vector256.LoadUnsafe(ref xRef, (uint)(x.Length - Vector256.Count)); + Vector256 yVec = Vector256.LoadUnsafe(ref yRef, (uint)(x.Length - Vector256.Count)); - Vector256 remainderMask = CreateRemainderMaskSingleVector256(x.Length - i); + Vector256 remainderMask = CreateRemainderMaskVector256(x.Length - i); xVec &= remainderMask; yVec &= remainderMask; @@ -720,41 +727,41 @@ private static float CosineSimilarityCore(ReadOnlySpan x, ReadOnlySpan= Vector128.Count) + if (Vector128.IsHardwareAccelerated && Vector128.IsSupported && x.Length >= Vector128.Count) { - ref float xRef = ref MemoryMarshal.GetReference(x); - ref float yRef = ref MemoryMarshal.GetReference(y); + ref T xRef = ref MemoryMarshal.GetReference(x); + ref T yRef = ref MemoryMarshal.GetReference(y); - Vector128 dotProductVector = Vector128.Zero; - Vector128 xSumOfSquaresVector = Vector128.Zero; - Vector128 ySumOfSquaresVector = Vector128.Zero; + Vector128 dotProductVector = Vector128.Zero; + Vector128 xSumOfSquaresVector = Vector128.Zero; + Vector128 ySumOfSquaresVector = Vector128.Zero; // Process vectors, summing their dot products and squares, as long as there's a vector's worth remaining. - int oneVectorFromEnd = x.Length - Vector128.Count; + int oneVectorFromEnd = x.Length - Vector128.Count; int i = 0; do { - Vector128 xVec = Vector128.LoadUnsafe(ref xRef, (uint)i); - Vector128 yVec = Vector128.LoadUnsafe(ref yRef, (uint)i); + Vector128 xVec = Vector128.LoadUnsafe(ref xRef, (uint)i); + Vector128 yVec = Vector128.LoadUnsafe(ref yRef, (uint)i); dotProductVector = FusedMultiplyAdd(xVec, yVec, dotProductVector); xSumOfSquaresVector = FusedMultiplyAdd(xVec, xVec, xSumOfSquaresVector); ySumOfSquaresVector = FusedMultiplyAdd(yVec, yVec, ySumOfSquaresVector); - i += Vector128.Count; + i += Vector128.Count; } while (i <= oneVectorFromEnd); // Process the last vector in the span, masking off elements already processed. 
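// For concreteness: with x.Length == 10 and Vector128<float>.Count == 4, the main loop stops at
// i == 8, leaving two elements. The code below reloads from offset x.Length - Count == 6 (lanes
// x[6..9], staying in bounds) and ANDs in CreateRemainderMaskVector128<float>(2), whose last two
// lanes are all-bits-set; lanes x[6] and x[7], already aggregated by the main loop, are zeroed so
// they do not contribute twice.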
if (i != x.Length) { - Vector128 xVec = Vector128.LoadUnsafe(ref xRef, (uint)(x.Length - Vector128.Count)); - Vector128 yVec = Vector128.LoadUnsafe(ref yRef, (uint)(x.Length - Vector128.Count)); + Vector128 xVec = Vector128.LoadUnsafe(ref xRef, (uint)(x.Length - Vector128.Count)); + Vector128 yVec = Vector128.LoadUnsafe(ref yRef, (uint)(x.Length - Vector128.Count)); - Vector128 remainderMask = CreateRemainderMaskSingleVector128(x.Length - i); + Vector128 remainderMask = CreateRemainderMaskVector128(x.Length - i); xVec &= remainderMask; yVec &= remainderMask; @@ -766,50 +773,51 @@ private static float CosineSimilarityCore(ReadOnlySpan x, ReadOnlySpanPerforms an aggregation over all elements in to produce a single-precision floating-point value. + /// The element type. /// Specifies the transform operation that should be applied to each element loaded from . /// /// Specifies the aggregation binary operation that should be applied to multiple values to aggregate them into a single value. /// The aggregation is applied after the transform is applied to each element. /// - private static float Aggregate( - ReadOnlySpan x) - where TTransformOperator : struct, IUnaryOperator - where TAggregationOperator : struct, IAggregationOperator + private static T Aggregate( + ReadOnlySpan x) + where TTransformOperator : struct, IUnaryOperator + where TAggregationOperator : struct, IAggregationOperator { // Since every branch has a cost and since that cost is // essentially lost for larger inputs, we do branches // in a way that allows us to have the minimum possible // for small sizes - ref float xRef = ref MemoryMarshal.GetReference(x); + ref T xRef = ref MemoryMarshal.GetReference(x); - nuint remainder = (uint)(x.Length); + nuint remainder = (uint)x.Length; - if (Vector512.IsHardwareAccelerated) + if (Vector512.IsHardwareAccelerated && Vector512.IsSupported && TTransformOperator.Vectorizable && Unsafe.SizeOf() >= 4) { - float result; + T result; - if (remainder >= (uint)(Vector512.Count)) + if (remainder >= (uint)Vector512.Count) { result = Vectorized512(ref xRef, remainder); } @@ -825,11 +833,11 @@ private static float Aggregate( return result; } - if (Vector256.IsHardwareAccelerated) + if (Vector256.IsHardwareAccelerated && Vector256.IsSupported && TTransformOperator.Vectorizable && Unsafe.SizeOf() >= 4) { - float result; + T result; - if (remainder >= (uint)(Vector256.Count)) + if (remainder >= (uint)Vector256.Count) { result = Vectorized256(ref xRef, remainder); } @@ -845,11 +853,11 @@ private static float Aggregate( return result; } - if (Vector128.IsHardwareAccelerated) + if (Vector128.IsHardwareAccelerated && Vector128.IsSupported && TTransformOperator.Vectorizable && Unsafe.SizeOf() >= 4) { - float result; + T result; - if (remainder >= (uint)(Vector128.Count)) + if (remainder >= (uint)Vector128.Count) { result = Vectorized128(ref xRef, remainder); } @@ -871,9 +879,9 @@ private static float Aggregate( return SoftwareFallback(ref xRef, remainder); [MethodImpl(MethodImplOptions.AggressiveInlining)] - static float SoftwareFallback(ref float xRef, nuint length) + static T SoftwareFallback(ref T xRef, nuint length) { - float result = TAggregationOperator.IdentityValue; + T result = TAggregationOperator.IdentityValue; for (nuint i = 0; i < length; i++) { @@ -883,31 +891,31 @@ static float SoftwareFallback(ref float xRef, nuint length) return result; } - static float Vectorized128(ref float xRef, nuint remainder) + static T Vectorized128(ref T xRef, nuint remainder) { - Vector128 vresult = 
Vector128.Create(TAggregationOperator.IdentityValue); + Vector128 vresult = Vector128.Create(TAggregationOperator.IdentityValue); // Preload the beginning and end so that overlapping accesses don't negatively impact the data - Vector128 beg = TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef)); - Vector128 end = TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count))); + Vector128 beg = TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef)); + Vector128 end = TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)Vector128.Count)); nuint misalignment = 0; - if (remainder > (uint)(Vector128.Count * 8)) + if (remainder > (uint)(Vector128.Count * 8)) { // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted. - fixed (float* px = &xRef) + fixed (T* px = &xRef) { - float* xPtr = px; + T* xPtr = px; // We need to the ensure the underlying data can be aligned and only align // it if it can. It is possible we have an unaligned ref, in which case we // can never achieve the required SIMD alignment. - bool canAlign = ((nuint)(xPtr) % sizeof(float)) == 0; + bool canAlign = ((nuint)xPtr % (nuint)sizeof(T)) == 0; if (canAlign) { @@ -917,30 +925,30 @@ static float Vectorized128(ref float xRef, nuint remainder) // are more expensive than unaligned loads and aligning both is significantly more // complex. - misalignment = ((uint)(sizeof(Vector128)) - ((nuint)(xPtr) % (uint)(sizeof(Vector128)))) / sizeof(float); + misalignment = ((uint)sizeof(Vector128) - ((nuint)xPtr % (uint)sizeof(Vector128))) / (uint)sizeof(T); xPtr += misalignment; - Debug.Assert(((nuint)(xPtr) % (uint)(sizeof(Vector128))) == 0); + Debug.Assert(((nuint)xPtr % (uint)sizeof(Vector128)) == 0); remainder -= misalignment; } - Vector128 vector1; - Vector128 vector2; - Vector128 vector3; - Vector128 vector4; + Vector128 vector1; + Vector128 vector2; + Vector128 vector3; + Vector128 vector4; // We only need to load, so there isn't a lot of benefit to doing non-temporal operations - while (remainder >= (uint)(Vector128.Count * 8)) + while (remainder >= (uint)(Vector128.Count * 8)) { // We load, process, and store the first four vectors - vector1 = TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 0))); - vector2 = TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 1))); - vector3 = TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 2))); - vector4 = TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 3))); + vector1 = TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 0))); + vector2 = TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 1))); + vector3 = TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 2))); + vector4 = TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 3))); vresult = TAggregationOperator.Invoke(vresult, vector1); vresult = TAggregationOperator.Invoke(vresult, vector2); @@ -949,10 +957,10 @@ static float Vectorized128(ref float xRef, nuint remainder) // We load, process, and store the next four vectors - vector1 = TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 4))); - vector2 = TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 5))); - vector3 = TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 6))); - 
vector4 = TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 7))); + vector1 = TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 4))); + vector2 = TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 5))); + vector3 = TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 6))); + vector4 = TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 7))); vresult = TAggregationOperator.Invoke(vresult, vector1); vresult = TAggregationOperator.Invoke(vresult, vector2); @@ -962,9 +970,9 @@ static float Vectorized128(ref float xRef, nuint remainder) // We adjust the source and destination references, then update // the count of remaining elements to process. - xPtr += (uint)(Vector128.Count * 8); + xPtr += (uint)(Vector128.Count * 8); - remainder -= (uint)(Vector128.Count * 8); + remainder -= (uint)(Vector128.Count * 8); } // Adjusting the refs here allows us to avoid pinning for very small inputs @@ -976,7 +984,7 @@ static float Vectorized128(ref float xRef, nuint remainder) // Store the first block. Handling this separately simplifies the latter code as we know // they come after and so we can relegate it to full blocks or the trailing elements - beg = Vector128.ConditionalSelect(CreateAlignmentMaskSingleVector128((int)(misalignment)), beg, Vector128.Create(TAggregationOperator.IdentityValue)); + beg = Vector128.ConditionalSelect(CreateAlignmentMaskVector128((int)misalignment), beg, Vector128.Create(TAggregationOperator.IdentityValue)); vresult = TAggregationOperator.Invoke(vresult, beg); // Process the remaining [0, Count * 7] elements via a jump table @@ -985,7 +993,7 @@ static float Vectorized128(ref float xRef, nuint remainder) // worst case end up just doing the identity operation here if there // were no trailing elements. - (nuint blocks, nuint trailing) = Math.DivRem(remainder, (nuint)(Vector128.Count)); + (nuint blocks, nuint trailing) = Math.DivRem(remainder, (nuint)Vector128.Count); blocks -= (misalignment == 0) ? 
1u : 0u; remainder -= trailing; @@ -993,49 +1001,49 @@ static float Vectorized128(ref float xRef, nuint remainder) { case 7: { - Vector128 vector = TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 7))); + Vector128 vector = TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 7))); vresult = TAggregationOperator.Invoke(vresult, vector); goto case 6; } case 6: { - Vector128 vector = TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 6))); + Vector128 vector = TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 6))); vresult = TAggregationOperator.Invoke(vresult, vector); goto case 5; } case 5: { - Vector128 vector = TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 5))); + Vector128 vector = TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 5))); vresult = TAggregationOperator.Invoke(vresult, vector); goto case 4; } case 4: { - Vector128 vector = TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 4))); + Vector128 vector = TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 4))); vresult = TAggregationOperator.Invoke(vresult, vector); goto case 3; } case 3: { - Vector128 vector = TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 3))); + Vector128 vector = TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 3))); vresult = TAggregationOperator.Invoke(vresult, vector); goto case 2; } case 2: { - Vector128 vector = TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 2))); + Vector128 vector = TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 2))); vresult = TAggregationOperator.Invoke(vresult, vector); goto case 1; } case 1: { - Vector128 vector = TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 1))); + Vector128 vector = TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 1))); vresult = TAggregationOperator.Invoke(vresult, vector); goto case 0; } @@ -1043,7 +1051,7 @@ static float Vectorized128(ref float xRef, nuint remainder) case 0: { // Store the last block, which includes any elements that wouldn't fill a full vector - end = Vector128.ConditionalSelect(CreateRemainderMaskSingleVector128((int)(trailing)), end, Vector128.Create(TAggregationOperator.IdentityValue)); + end = Vector128.ConditionalSelect(CreateRemainderMaskVector128((int)trailing), end, Vector128.Create(TAggregationOperator.IdentityValue)); vresult = TAggregationOperator.Invoke(vresult, end); break; } @@ -1053,9 +1061,9 @@ static float Vectorized128(ref float xRef, nuint remainder) } [MethodImpl(MethodImplOptions.AggressiveInlining)] - static float Vectorized128Small(ref float xRef, nuint remainder) + static T Vectorized128Small(ref T xRef, nuint remainder) { - float result = TAggregationOperator.IdentityValue; + T result = TAggregationOperator.IdentityValue; switch (remainder) { @@ -1086,31 +1094,31 @@ static float Vectorized128Small(ref float xRef, nuint remainder) return result; } - static float Vectorized256(ref float xRef, nuint remainder) + static T Vectorized256(ref T xRef, nuint remainder) { - Vector256 vresult = 
Vector256.Create(TAggregationOperator.IdentityValue); + Vector256 vresult = Vector256.Create(TAggregationOperator.IdentityValue); // Preload the beginning and end so that overlapping accesses don't negatively impact the data - Vector256 beg = TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef)); - Vector256 end = TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count))); + Vector256 beg = TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef)); + Vector256 end = TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)Vector256.Count)); nuint misalignment = 0; - if (remainder > (uint)(Vector256.Count * 8)) + if (remainder > (uint)(Vector256.Count * 8)) { // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted. - fixed (float* px = &xRef) + fixed (T* px = &xRef) { - float* xPtr = px; + T* xPtr = px; // We need to the ensure the underlying data can be aligned and only align // it if it can. It is possible we have an unaligned ref, in which case we // can never achieve the required SIMD alignment. - bool canAlign = ((nuint)(xPtr) % sizeof(float)) == 0; + bool canAlign = ((nuint)xPtr % (nuint)sizeof(T)) == 0; if (canAlign) { @@ -1120,30 +1128,30 @@ static float Vectorized256(ref float xRef, nuint remainder) // are more expensive than unaligned loads and aligning both is significantly more // complex. - misalignment = ((uint)(sizeof(Vector256)) - ((nuint)(xPtr) % (uint)(sizeof(Vector256)))) / sizeof(float); + misalignment = ((uint)sizeof(Vector256) - ((nuint)xPtr % (uint)sizeof(Vector256))) / (uint)sizeof(T); xPtr += misalignment; - Debug.Assert(((nuint)(xPtr) % (uint)(sizeof(Vector256))) == 0); + Debug.Assert(((nuint)xPtr % (uint)sizeof(Vector256)) == 0); remainder -= misalignment; } - Vector256 vector1; - Vector256 vector2; - Vector256 vector3; - Vector256 vector4; + Vector256 vector1; + Vector256 vector2; + Vector256 vector3; + Vector256 vector4; // We only need to load, so there isn't a lot of benefit to doing non-temporal operations - while (remainder >= (uint)(Vector256.Count * 8)) + while (remainder >= (uint)(Vector256.Count * 8)) { // We load, process, and store the first four vectors - vector1 = TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 0))); - vector2 = TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 1))); - vector3 = TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 2))); - vector4 = TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 3))); + vector1 = TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 0))); + vector2 = TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 1))); + vector3 = TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 2))); + vector4 = TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 3))); vresult = TAggregationOperator.Invoke(vresult, vector1); vresult = TAggregationOperator.Invoke(vresult, vector2); @@ -1152,10 +1160,10 @@ static float Vectorized256(ref float xRef, nuint remainder) // We load, process, and store the next four vectors - vector1 = TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 4))); - vector2 = TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 5))); - vector3 = TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 6))); - 
vector4 = TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 7))); + vector1 = TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 4))); + vector2 = TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 5))); + vector3 = TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 6))); + vector4 = TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 7))); vresult = TAggregationOperator.Invoke(vresult, vector1); vresult = TAggregationOperator.Invoke(vresult, vector2); @@ -1165,9 +1173,9 @@ static float Vectorized256(ref float xRef, nuint remainder) // We adjust the source and destination references, then update // the count of remaining elements to process. - xPtr += (uint)(Vector256.Count * 8); + xPtr += (uint)(Vector256.Count * 8); - remainder -= (uint)(Vector256.Count * 8); + remainder -= (uint)(Vector256.Count * 8); } // Adjusting the refs here allows us to avoid pinning for very small inputs @@ -1179,7 +1187,7 @@ static float Vectorized256(ref float xRef, nuint remainder) // Store the first block. Handling this separately simplifies the latter code as we know // they come after and so we can relegate it to full blocks or the trailing elements - beg = Vector256.ConditionalSelect(CreateAlignmentMaskSingleVector256((int)(misalignment)), beg, Vector256.Create(TAggregationOperator.IdentityValue)); + beg = Vector256.ConditionalSelect(CreateAlignmentMaskVector256((int)misalignment), beg, Vector256.Create(TAggregationOperator.IdentityValue)); vresult = TAggregationOperator.Invoke(vresult, beg); // Process the remaining [0, Count * 7] elements via a jump table @@ -1188,7 +1196,7 @@ static float Vectorized256(ref float xRef, nuint remainder) // worst case end up just doing the identity operation here if there // were no trailing elements. - (nuint blocks, nuint trailing) = Math.DivRem(remainder, (nuint)(Vector256.Count)); + (nuint blocks, nuint trailing) = Math.DivRem(remainder, (nuint)Vector256.Count); blocks -= (misalignment == 0) ? 
1u : 0u; remainder -= trailing; @@ -1196,49 +1204,49 @@ static float Vectorized256(ref float xRef, nuint remainder) { case 7: { - Vector256 vector = TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 7))); + Vector256 vector = TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 7))); vresult = TAggregationOperator.Invoke(vresult, vector); goto case 6; } case 6: { - Vector256 vector = TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 6))); + Vector256 vector = TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 6))); vresult = TAggregationOperator.Invoke(vresult, vector); goto case 5; } case 5: { - Vector256 vector = TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 5))); + Vector256 vector = TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 5))); vresult = TAggregationOperator.Invoke(vresult, vector); goto case 4; } case 4: { - Vector256 vector = TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 4))); + Vector256 vector = TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 4))); vresult = TAggregationOperator.Invoke(vresult, vector); goto case 3; } case 3: { - Vector256 vector = TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 3))); + Vector256 vector = TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 3))); vresult = TAggregationOperator.Invoke(vresult, vector); goto case 2; } case 2: { - Vector256 vector = TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 2))); + Vector256 vector = TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 2))); vresult = TAggregationOperator.Invoke(vresult, vector); goto case 1; } case 1: { - Vector256 vector = TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 1))); + Vector256 vector = TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 1))); vresult = TAggregationOperator.Invoke(vresult, vector); goto case 0; } @@ -1246,7 +1254,7 @@ static float Vectorized256(ref float xRef, nuint remainder) case 0: { // Store the last block, which includes any elements that wouldn't fill a full vector - end = Vector256.ConditionalSelect(CreateRemainderMaskSingleVector256((int)(trailing)), end, Vector256.Create(TAggregationOperator.IdentityValue)); + end = Vector256.ConditionalSelect(CreateRemainderMaskVector256((int)trailing), end, Vector256.Create(TAggregationOperator.IdentityValue)); vresult = TAggregationOperator.Invoke(vresult, end); break; } @@ -1256,9 +1264,9 @@ static float Vectorized256(ref float xRef, nuint remainder) } [MethodImpl(MethodImplOptions.AggressiveInlining)] - static float Vectorized256Small(ref float xRef, nuint remainder) + static T Vectorized256Small(ref T xRef, nuint remainder) { - float result = TAggregationOperator.IdentityValue; + T result = TAggregationOperator.IdentityValue; switch (remainder) { @@ -1267,12 +1275,12 @@ static float Vectorized256Small(ref float xRef, nuint remainder) case 5: { Debug.Assert(Vector128.IsHardwareAccelerated); - Vector128 vresult = Vector128.Create(TAggregationOperator.IdentityValue); + Vector128 vresult = 
Vector128.Create(TAggregationOperator.IdentityValue); - Vector128 beg = TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef)); - Vector128 end = TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count))); + Vector128 beg = TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef)); + Vector128 end = TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)Vector128.Count)); - end = Vector128.ConditionalSelect(CreateRemainderMaskSingleVector128((int)(remainder % (uint)(Vector128.Count))), end, Vector128.Create(TAggregationOperator.IdentityValue)); + end = Vector128.ConditionalSelect(CreateRemainderMaskVector128((int)(remainder % (uint)Vector128.Count)), end, Vector128.Create(TAggregationOperator.IdentityValue)); vresult = TAggregationOperator.Invoke(vresult, beg); vresult = TAggregationOperator.Invoke(vresult, end); @@ -1284,9 +1292,9 @@ static float Vectorized256Small(ref float xRef, nuint remainder) case 4: { Debug.Assert(Vector128.IsHardwareAccelerated); - Vector128 vresult = Vector128.Create(TAggregationOperator.IdentityValue); + Vector128 vresult = Vector128.Create(TAggregationOperator.IdentityValue); - Vector128 beg = TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef)); + Vector128 beg = TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef)); vresult = TAggregationOperator.Invoke(vresult, beg); result = TAggregationOperator.Invoke(vresult); @@ -1320,31 +1328,31 @@ static float Vectorized256Small(ref float xRef, nuint remainder) return result; } - static float Vectorized512(ref float xRef, nuint remainder) + static T Vectorized512(ref T xRef, nuint remainder) { - Vector512 vresult = Vector512.Create(TAggregationOperator.IdentityValue); + Vector512 vresult = Vector512.Create(TAggregationOperator.IdentityValue); // Preload the beginning and end so that overlapping accesses don't negatively impact the data - Vector512 beg = TTransformOperator.Invoke(Vector512.LoadUnsafe(ref xRef)); - Vector512 end = TTransformOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count))); + Vector512 beg = TTransformOperator.Invoke(Vector512.LoadUnsafe(ref xRef)); + Vector512 end = TTransformOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)Vector512.Count)); nuint misalignment = 0; - if (remainder > (uint)(Vector512.Count * 8)) + if (remainder > (uint)(Vector512.Count * 8)) { // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted. - fixed (float* px = &xRef) + fixed (T* px = &xRef) { - float* xPtr = px; + T* xPtr = px; // We need to the ensure the underlying data can be aligned and only align // it if it can. It is possible we have an unaligned ref, in which case we // can never achieve the required SIMD alignment. - bool canAlign = ((nuint)(xPtr) % sizeof(float)) == 0; + bool canAlign = ((nuint)xPtr % (nuint)sizeof(T)) == 0; if (canAlign) { @@ -1354,30 +1362,30 @@ static float Vectorized512(ref float xRef, nuint remainder) // are more expensive than unaligned loads and aligning both is significantly more // complex. 
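// [Editor's aside] Worked numbers for the misalignment expression just below, assuming
// Vector512<float>: vectors are 64 bytes, elements 4 bytes. For a pointer 20 bytes past a
// 64-byte boundary:
//   64 - (addr % 64) = 64 - 20 = 44 bytes to the next boundary
//   44 / sizeof(float) = 11 elements to skip before xPtr is aligned
// (If the address already sits on a boundary, the expression yields a full vector's worth of
// elements; those are covered by the preloaded `beg` together with the alignment mask.) A
// hypothetical helper with the same arithmetic, using System.Runtime.Intrinsics:
static unsafe nuint ElementsUntilAligned<T>(T* p) where T : unmanaged
{
    nuint addr = (nuint)p;
    return ((uint)sizeof(Vector512<T>) - (addr % (uint)sizeof(Vector512<T>))) / (uint)sizeof(T);
}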
- misalignment = ((uint)(sizeof(Vector512)) - ((nuint)(xPtr) % (uint)(sizeof(Vector512)))) / sizeof(float); + misalignment = ((uint)sizeof(Vector512) - ((nuint)xPtr % (uint)sizeof(Vector512))) / (uint)sizeof(T); xPtr += misalignment; - Debug.Assert(((nuint)(xPtr) % (uint)(sizeof(Vector512))) == 0); + Debug.Assert(((nuint)xPtr % (uint)sizeof(Vector512)) == 0); remainder -= misalignment; } - Vector512 vector1; - Vector512 vector2; - Vector512 vector3; - Vector512 vector4; + Vector512 vector1; + Vector512 vector2; + Vector512 vector3; + Vector512 vector4; // We only need to load, so there isn't a lot of benefit to doing non-temporal operations - while (remainder >= (uint)(Vector512.Count * 8)) + while (remainder >= (uint)(Vector512.Count * 8)) { // We load, process, and store the first four vectors - vector1 = TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 0))); - vector2 = TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 1))); - vector3 = TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 2))); - vector4 = TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 3))); + vector1 = TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 0))); + vector2 = TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 1))); + vector3 = TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 2))); + vector4 = TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 3))); vresult = TAggregationOperator.Invoke(vresult, vector1); vresult = TAggregationOperator.Invoke(vresult, vector2); @@ -1386,10 +1394,10 @@ static float Vectorized512(ref float xRef, nuint remainder) // We load, process, and store the next four vectors - vector1 = TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 4))); - vector2 = TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 5))); - vector3 = TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 6))); - vector4 = TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 7))); + vector1 = TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 4))); + vector2 = TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 5))); + vector3 = TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 6))); + vector4 = TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 7))); vresult = TAggregationOperator.Invoke(vresult, vector1); vresult = TAggregationOperator.Invoke(vresult, vector2); @@ -1399,9 +1407,9 @@ static float Vectorized512(ref float xRef, nuint remainder) // We adjust the source and destination references, then update // the count of remaining elements to process. - xPtr += (uint)(Vector512.Count * 8); + xPtr += (uint)(Vector512.Count * 8); - remainder -= (uint)(Vector512.Count * 8); + remainder -= (uint)(Vector512.Count * 8); } // Adjusting the refs here allows us to avoid pinning for very small inputs @@ -1413,7 +1421,7 @@ static float Vectorized512(ref float xRef, nuint remainder) // Store the first block. 
Handling this separately simplifies the latter code as we know // they come after and so we can relegate it to full blocks or the trailing elements - beg = Vector512.ConditionalSelect(CreateAlignmentMaskSingleVector512((int)(misalignment)), beg, Vector512.Create(TAggregationOperator.IdentityValue)); + beg = Vector512.ConditionalSelect(CreateAlignmentMaskVector512((int)misalignment), beg, Vector512.Create(TAggregationOperator.IdentityValue)); vresult = TAggregationOperator.Invoke(vresult, beg); // Process the remaining [0, Count * 7] elements via a jump table @@ -1422,7 +1430,7 @@ static float Vectorized512(ref float xRef, nuint remainder) // worst case end up just doing the identity operation here if there // were no trailing elements. - (nuint blocks, nuint trailing) = Math.DivRem(remainder, (nuint)(Vector512.Count)); + (nuint blocks, nuint trailing) = Math.DivRem(remainder, (nuint)Vector512.Count); blocks -= (misalignment == 0) ? 1u : 0u; remainder -= trailing; @@ -1430,49 +1438,49 @@ static float Vectorized512(ref float xRef, nuint remainder) { case 7: { - Vector512 vector = TTransformOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 7))); + Vector512 vector = TTransformOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 7))); vresult = TAggregationOperator.Invoke(vresult, vector); goto case 6; } case 6: { - Vector512 vector = TTransformOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 6))); + Vector512 vector = TTransformOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 6))); vresult = TAggregationOperator.Invoke(vresult, vector); goto case 5; } case 5: { - Vector512 vector = TTransformOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 5))); + Vector512 vector = TTransformOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 5))); vresult = TAggregationOperator.Invoke(vresult, vector); goto case 4; } case 4: { - Vector512 vector = TTransformOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 4))); + Vector512 vector = TTransformOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 4))); vresult = TAggregationOperator.Invoke(vresult, vector); goto case 3; } case 3: { - Vector512 vector = TTransformOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 3))); + Vector512 vector = TTransformOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 3))); vresult = TAggregationOperator.Invoke(vresult, vector); goto case 2; } case 2: { - Vector512 vector = TTransformOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 2))); + Vector512 vector = TTransformOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 2))); vresult = TAggregationOperator.Invoke(vresult, vector); goto case 1; } case 1: { - Vector512 vector = TTransformOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 1))); + Vector512 vector = TTransformOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 1))); vresult = TAggregationOperator.Invoke(vresult, vector); goto case 0; } @@ -1480,7 +1488,7 @@ static float Vectorized512(ref float xRef, nuint remainder) case 0: { // Store the last block, which includes any elements that wouldn't fill a full vector - end = 
Vector512.ConditionalSelect(CreateRemainderMaskSingleVector512((int)(trailing)), end, Vector512.Create(TAggregationOperator.IdentityValue)); + end = Vector512.ConditionalSelect(CreateRemainderMaskVector512((int)trailing), end, Vector512.Create(TAggregationOperator.IdentityValue)); vresult = TAggregationOperator.Invoke(vresult, end); break; } @@ -1490,9 +1498,9 @@ static float Vectorized512(ref float xRef, nuint remainder) } [MethodImpl(MethodImplOptions.AggressiveInlining)] - static float Vectorized512Small(ref float xRef, nuint remainder) + static T Vectorized512Small(ref T xRef, nuint remainder) { - float result = TAggregationOperator.IdentityValue; + T result = TAggregationOperator.IdentityValue; switch (remainder) { @@ -1505,12 +1513,12 @@ static float Vectorized512Small(ref float xRef, nuint remainder) case 9: { Debug.Assert(Vector256.IsHardwareAccelerated); - Vector256 vresult = Vector256.Create(TAggregationOperator.IdentityValue); + Vector256 vresult = Vector256.Create(TAggregationOperator.IdentityValue); - Vector256 beg = TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef)); - Vector256 end = TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count))); + Vector256 beg = TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef)); + Vector256 end = TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)Vector256.Count)); - end = Vector256.ConditionalSelect(CreateRemainderMaskSingleVector256((int)(remainder % (uint)(Vector256.Count))), end, Vector256.Create(TAggregationOperator.IdentityValue)); + end = Vector256.ConditionalSelect(CreateRemainderMaskVector256((int)(remainder % (uint)Vector256.Count)), end, Vector256.Create(TAggregationOperator.IdentityValue)); vresult = TAggregationOperator.Invoke(vresult, beg); vresult = TAggregationOperator.Invoke(vresult, end); @@ -1522,9 +1530,9 @@ static float Vectorized512Small(ref float xRef, nuint remainder) case 8: { Debug.Assert(Vector256.IsHardwareAccelerated); - Vector256 vresult = Vector256.Create(TAggregationOperator.IdentityValue); + Vector256 vresult = Vector256.Create(TAggregationOperator.IdentityValue); - Vector256 beg = TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef)); + Vector256 beg = TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef)); vresult = TAggregationOperator.Invoke(vresult, beg); result = TAggregationOperator.Invoke(vresult); @@ -1536,12 +1544,12 @@ static float Vectorized512Small(ref float xRef, nuint remainder) case 5: { Debug.Assert(Vector128.IsHardwareAccelerated); - Vector128 vresult = Vector128.Create(TAggregationOperator.IdentityValue); + Vector128 vresult = Vector128.Create(TAggregationOperator.IdentityValue); - Vector128 beg = TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef)); - Vector128 end = TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count))); + Vector128 beg = TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef)); + Vector128 end = TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)Vector128.Count)); - end = Vector128.ConditionalSelect(CreateRemainderMaskSingleVector128((int)(remainder % (uint)(Vector128.Count))), end, Vector128.Create(TAggregationOperator.IdentityValue)); + end = Vector128.ConditionalSelect(CreateRemainderMaskVector128((int)(remainder % (uint)Vector128.Count)), end, Vector128.Create(TAggregationOperator.IdentityValue)); vresult = TAggregationOperator.Invoke(vresult, beg); vresult = TAggregationOperator.Invoke(vresult, end); @@ -1553,9 +1561,9 
@@ static float Vectorized512Small(ref float xRef, nuint remainder) case 4: { Debug.Assert(Vector128.IsHardwareAccelerated); - Vector128 vresult = Vector128.Create(TAggregationOperator.IdentityValue); + Vector128 vresult = Vector128.Create(TAggregationOperator.IdentityValue); - Vector128 beg = TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef)); + Vector128 beg = TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef)); vresult = TAggregationOperator.Invoke(vresult, beg); result = TAggregationOperator.Invoke(vresult); @@ -1591,15 +1599,16 @@ static float Vectorized512Small(ref float xRef, nuint remainder) } /// Performs an aggregation over all pair-wise elements in and to produce a single-precision floating-point value. + /// The element type. /// Specifies the binary operation that should be applied to the pair-wise elements loaded from and . /// /// Specifies the aggregation binary operation that should be applied to multiple values to aggregate them into a single value. /// The aggregation is applied to the results of the binary operations on the pair-wise values. /// - private static float Aggregate( - ReadOnlySpan x, ReadOnlySpan y) - where TBinaryOperator : struct, IBinaryOperator - where TAggregationOperator : struct, IAggregationOperator + private static T Aggregate( + ReadOnlySpan x, ReadOnlySpan y) + where TBinaryOperator : struct, IBinaryOperator + where TAggregationOperator : struct, IAggregationOperator { if (x.Length != y.Length) { @@ -1611,16 +1620,16 @@ private static float Aggregate( // in a way that allows us to have the minimum possible // for small sizes - ref float xRef = ref MemoryMarshal.GetReference(x); - ref float yRef = ref MemoryMarshal.GetReference(y); + ref T xRef = ref MemoryMarshal.GetReference(x); + ref T yRef = ref MemoryMarshal.GetReference(y); - nuint remainder = (uint)(x.Length); + nuint remainder = (uint)x.Length; - if (Vector512.IsHardwareAccelerated) + if (Vector512.IsHardwareAccelerated && Vector512.IsSupported && Unsafe.SizeOf() >= 4) { - float result; + T result; - if (remainder >= (uint)(Vector512.Count)) + if (remainder >= (uint)Vector512.Count) { result = Vectorized512(ref xRef, ref yRef, remainder); } @@ -1636,11 +1645,11 @@ private static float Aggregate( return result; } - if (Vector256.IsHardwareAccelerated) + if (Vector256.IsHardwareAccelerated && Vector256.IsSupported && Unsafe.SizeOf() >= 4) { - float result; + T result; - if (remainder >= (uint)(Vector256.Count)) + if (remainder >= (uint)Vector256.Count) { result = Vectorized256(ref xRef, ref yRef, remainder); } @@ -1656,11 +1665,11 @@ private static float Aggregate( return result; } - if (Vector128.IsHardwareAccelerated) + if (Vector128.IsHardwareAccelerated && Vector128.IsSupported && Unsafe.SizeOf() >= 4) { - float result; + T result; - if (remainder >= (uint)(Vector128.Count)) + if (remainder >= (uint)Vector128.Count) { result = Vectorized128(ref xRef, ref yRef, remainder); } @@ -1682,9 +1691,9 @@ private static float Aggregate( return SoftwareFallback(ref xRef, ref yRef, remainder); [MethodImpl(MethodImplOptions.AggressiveInlining)] - static float SoftwareFallback(ref float xRef, ref float yRef, nuint length) + static T SoftwareFallback(ref T xRef, ref T yRef, nuint length) { - float result = TAggregationOperator.IdentityValue; + T result = TAggregationOperator.IdentityValue; for (nuint i = 0; i < length; i++) { @@ -1695,35 +1704,35 @@ static float SoftwareFallback(ref float xRef, ref float yRef, nuint length) return result; } - static float Vectorized128(ref float xRef, ref 
float yRef, nuint remainder) + static T Vectorized128(ref T xRef, ref T yRef, nuint remainder) { - Vector128 vresult = Vector128.Create(TAggregationOperator.IdentityValue); + Vector128 vresult = Vector128.Create(TAggregationOperator.IdentityValue); // Preload the beginning and end so that overlapping accesses don't negatively impact the data - Vector128 beg = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef), - Vector128.LoadUnsafe(ref yRef)); - Vector128 end = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count)), - Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count))); + Vector128 beg = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef), + Vector128.LoadUnsafe(ref yRef)); + Vector128 end = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)Vector128.Count), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)Vector128.Count)); nuint misalignment = 0; - if (remainder > (uint)(Vector128.Count * 8)) + if (remainder > (uint)(Vector128.Count * 8)) { // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted. - fixed (float* px = &xRef) - fixed (float* py = &yRef) + fixed (T* px = &xRef) + fixed (T* py = &yRef) { - float* xPtr = px; - float* yPtr = py; + T* xPtr = px; + T* yPtr = py; // We need to the ensure the underlying data can be aligned and only align // it if it can. It is possible we have an unaligned ref, in which case we // can never achieve the required SIMD alignment. - bool canAlign = ((nuint)(xPtr) % sizeof(float)) == 0; + bool canAlign = ((nuint)xPtr % (nuint)sizeof(T)) == 0; if (canAlign) { @@ -1733,35 +1742,35 @@ static float Vectorized128(ref float xRef, ref float yRef, nuint remainder) // are more expensive than unaligned loads and aligning both is significantly more // complex. 
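// [Editor's aside] Stepping back from the alignment details: everything in this method funnels
// through two operator shapes. The diff's IBinaryOperator/IAggregationOperator interfaces are
// internal and not shown in this hunk, so the declarations below are a sketch inferred from how
// they are used (static abstract members, an identity for the fold, vector and horizontal-reduce
// overloads), not the exact source:
interface IBinaryOperator<T>
{
    static abstract T Invoke(T x, T y);                                   // scalar fallback path
    static abstract Vector128<T> Invoke(Vector128<T> x, Vector128<T> y);  // SIMD path
}

interface IAggregationOperator<T> : IBinaryOperator<T>
{
    static abstract T IdentityValue { get; }  // e.g. zero for a sum, one for a product
    static abstract T Invoke(Vector128<T> x); // horizontal reduction of the partial-results vector
}
// Under this model, Dot is Aggregate<T, Multiply, Add>: the binary operator combines x[i] with
// y[i], and the aggregation operator folds those products starting from its identity value.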
- misalignment = ((uint)(sizeof(Vector128)) - ((nuint)(xPtr) % (uint)(sizeof(Vector128)))) / sizeof(float); + misalignment = ((uint)sizeof(Vector128) - ((nuint)xPtr % (uint)sizeof(Vector128))) / (uint)sizeof(T); xPtr += misalignment; yPtr += misalignment; - Debug.Assert(((nuint)(xPtr) % (uint)(sizeof(Vector128))) == 0); + Debug.Assert(((nuint)xPtr % (uint)sizeof(Vector128)) == 0); remainder -= misalignment; } - Vector128 vector1; - Vector128 vector2; - Vector128 vector3; - Vector128 vector4; + Vector128 vector1; + Vector128 vector2; + Vector128 vector3; + Vector128 vector4; // We only need to load, so there isn't a lot of benefit to doing non-temporal operations - while (remainder >= (uint)(Vector128.Count * 8)) + while (remainder >= (uint)(Vector128.Count * 8)) { // We load, process, and store the first four vectors - vector1 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 0)), - Vector128.Load(yPtr + (uint)(Vector128.Count * 0))); - vector2 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 1)), - Vector128.Load(yPtr + (uint)(Vector128.Count * 1))); - vector3 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 2)), - Vector128.Load(yPtr + (uint)(Vector128.Count * 2))); - vector4 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 3)), - Vector128.Load(yPtr + (uint)(Vector128.Count * 3))); + vector1 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 0)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 0))); + vector2 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 1)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 1))); + vector3 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 2)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 2))); + vector4 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 3)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 3))); vresult = TAggregationOperator.Invoke(vresult, vector1); vresult = TAggregationOperator.Invoke(vresult, vector2); @@ -1770,14 +1779,14 @@ static float Vectorized128(ref float xRef, ref float yRef, nuint remainder) // We load, process, and store the next four vectors - vector1 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 4)), - Vector128.Load(yPtr + (uint)(Vector128.Count * 4))); - vector2 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 5)), - Vector128.Load(yPtr + (uint)(Vector128.Count * 5))); - vector3 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 6)), - Vector128.Load(yPtr + (uint)(Vector128.Count * 6))); - vector4 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 7)), - Vector128.Load(yPtr + (uint)(Vector128.Count * 7))); + vector1 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 4)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 4))); + vector2 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 5)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 5))); + vector3 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 6)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 6))); + vector4 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 7)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 7))); vresult = TAggregationOperator.Invoke(vresult, vector1); vresult = TAggregationOperator.Invoke(vresult, vector2); @@ -1787,10 +1796,10 @@ static float Vectorized128(ref float xRef, ref 
float yRef, nuint remainder) // We adjust the source and destination references, then update // the count of remaining elements to process. - xPtr += (uint)(Vector128.Count * 8); - yPtr += (uint)(Vector128.Count * 8); + xPtr += (uint)(Vector128.Count * 8); + yPtr += (uint)(Vector128.Count * 8); - remainder -= (uint)(Vector128.Count * 8); + remainder -= (uint)(Vector128.Count * 8); } // Adjusting the refs here allows us to avoid pinning for very small inputs @@ -1803,7 +1812,7 @@ static float Vectorized128(ref float xRef, ref float yRef, nuint remainder) // Store the first block. Handling this separately simplifies the latter code as we know // they come after and so we can relegate it to full blocks or the trailing elements - beg = Vector128.ConditionalSelect(CreateAlignmentMaskSingleVector128((int)(misalignment)), beg, Vector128.Create(TAggregationOperator.IdentityValue)); + beg = Vector128.ConditionalSelect(CreateAlignmentMaskVector128((int)misalignment), beg, Vector128.Create(TAggregationOperator.IdentityValue)); vresult = TAggregationOperator.Invoke(vresult, beg); // Process the remaining [0, Count * 7] elements via a jump table @@ -1812,7 +1821,7 @@ static float Vectorized128(ref float xRef, ref float yRef, nuint remainder) // worst case end up just doing the identity operation here if there // were no trailing elements. - (nuint blocks, nuint trailing) = Math.DivRem(remainder, (nuint)(Vector128.Count)); + (nuint blocks, nuint trailing) = Math.DivRem(remainder, (nuint)Vector128.Count); blocks -= (misalignment == 0) ? 1u : 0u; remainder -= trailing; @@ -1820,56 +1829,56 @@ static float Vectorized128(ref float xRef, ref float yRef, nuint remainder) { case 7: { - Vector128 vector = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 7)), - Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 7))); + Vector128 vector = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 7)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 7))); vresult = TAggregationOperator.Invoke(vresult, vector); goto case 6; } case 6: { - Vector128 vector = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 6)), - Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 6))); + Vector128 vector = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 6)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 6))); vresult = TAggregationOperator.Invoke(vresult, vector); goto case 5; } case 5: { - Vector128 vector = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 5)), - Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 5))); + Vector128 vector = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 5)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 5))); vresult = TAggregationOperator.Invoke(vresult, vector); goto case 4; } case 4: { - Vector128 vector = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 4)), - Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 4))); + Vector128 vector = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 4)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 4))); vresult = TAggregationOperator.Invoke(vresult, vector); goto case 3; } case 
3: { - Vector128 vector = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 3)), - Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 3))); + Vector128 vector = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 3)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 3))); vresult = TAggregationOperator.Invoke(vresult, vector); goto case 2; } case 2: { - Vector128 vector = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 2)), - Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 2))); + Vector128 vector = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 2)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 2))); vresult = TAggregationOperator.Invoke(vresult, vector); goto case 1; } case 1: { - Vector128 vector = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 1)), - Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 1))); + Vector128 vector = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 1)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 1))); vresult = TAggregationOperator.Invoke(vresult, vector); goto case 0; } @@ -1877,7 +1886,7 @@ static float Vectorized128(ref float xRef, ref float yRef, nuint remainder) case 0: { // Store the last block, which includes any elements that wouldn't fill a full vector - end = Vector128.ConditionalSelect(CreateRemainderMaskSingleVector128((int)(trailing)), end, Vector128.Create(TAggregationOperator.IdentityValue)); + end = Vector128.ConditionalSelect(CreateRemainderMaskVector128((int)trailing), end, Vector128.Create(TAggregationOperator.IdentityValue)); vresult = TAggregationOperator.Invoke(vresult, end); break; } @@ -1887,9 +1896,9 @@ static float Vectorized128(ref float xRef, ref float yRef, nuint remainder) } [MethodImpl(MethodImplOptions.AggressiveInlining)] - static float Vectorized128Small(ref float xRef, ref float yRef, nuint remainder) + static T Vectorized128Small(ref T xRef, ref T yRef, nuint remainder) { - float result = TAggregationOperator.IdentityValue; + T result = TAggregationOperator.IdentityValue; switch (remainder) { @@ -1922,35 +1931,35 @@ static float Vectorized128Small(ref float xRef, ref float yRef, nuint remainder) return result; } - static float Vectorized256(ref float xRef, ref float yRef, nuint remainder) + static T Vectorized256(ref T xRef, ref T yRef, nuint remainder) { - Vector256 vresult = Vector256.Create(TAggregationOperator.IdentityValue); + Vector256 vresult = Vector256.Create(TAggregationOperator.IdentityValue); // Preload the beginning and end so that overlapping accesses don't negatively impact the data - Vector256 beg = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef), - Vector256.LoadUnsafe(ref yRef)); - Vector256 end = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count)), - Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count))); + Vector256 beg = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef), + Vector256.LoadUnsafe(ref yRef)); + Vector256 end = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)Vector256.Count), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)Vector256.Count)); nuint misalignment = 0; - if (remainder > (uint)(Vector256.Count * 8)) + if (remainder > 
(uint)(Vector256.Count * 8)) { // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted. - fixed (float* px = &xRef) - fixed (float* py = &yRef) + fixed (T* px = &xRef) + fixed (T* py = &yRef) { - float* xPtr = px; - float* yPtr = py; + T* xPtr = px; + T* yPtr = py; // We need to the ensure the underlying data can be aligned and only align // it if it can. It is possible we have an unaligned ref, in which case we // can never achieve the required SIMD alignment. - bool canAlign = ((nuint)(xPtr) % sizeof(float)) == 0; + bool canAlign = ((nuint)xPtr % (nuint)sizeof(T)) == 0; if (canAlign) { @@ -1960,35 +1969,35 @@ static float Vectorized256(ref float xRef, ref float yRef, nuint remainder) // are more expensive than unaligned loads and aligning both is significantly more // complex. - misalignment = ((uint)(sizeof(Vector256)) - ((nuint)(xPtr) % (uint)(sizeof(Vector256)))) / sizeof(float); + misalignment = ((uint)sizeof(Vector256) - ((nuint)xPtr % (uint)sizeof(Vector256))) / (uint)sizeof(T); xPtr += misalignment; yPtr += misalignment; - Debug.Assert(((nuint)(xPtr) % (uint)(sizeof(Vector256))) == 0); + Debug.Assert(((nuint)xPtr % (uint)sizeof(Vector256)) == 0); remainder -= misalignment; } - Vector256 vector1; - Vector256 vector2; - Vector256 vector3; - Vector256 vector4; + Vector256 vector1; + Vector256 vector2; + Vector256 vector3; + Vector256 vector4; // We only need to load, so there isn't a lot of benefit to doing non-temporal operations - while (remainder >= (uint)(Vector256.Count * 8)) + while (remainder >= (uint)(Vector256.Count * 8)) { // We load, process, and store the first four vectors - vector1 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 0)), - Vector256.Load(yPtr + (uint)(Vector256.Count * 0))); - vector2 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 1)), - Vector256.Load(yPtr + (uint)(Vector256.Count * 1))); - vector3 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 2)), - Vector256.Load(yPtr + (uint)(Vector256.Count * 2))); - vector4 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 3)), - Vector256.Load(yPtr + (uint)(Vector256.Count * 3))); + vector1 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 0)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 0))); + vector2 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 1)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 1))); + vector3 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 2)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 2))); + vector4 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 3)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 3))); vresult = TAggregationOperator.Invoke(vresult, vector1); vresult = TAggregationOperator.Invoke(vresult, vector2); @@ -1997,14 +2006,14 @@ static float Vectorized256(ref float xRef, ref float yRef, nuint remainder) // We load, process, and store the next four vectors - vector1 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 4)), - Vector256.Load(yPtr + (uint)(Vector256.Count * 4))); - vector2 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 5)), - Vector256.Load(yPtr + (uint)(Vector256.Count * 5))); - vector3 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 6)), - Vector256.Load(yPtr + (uint)(Vector256.Count * 
6))); - vector4 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 7)), - Vector256.Load(yPtr + (uint)(Vector256.Count * 7))); + vector1 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 4)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 4))); + vector2 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 5)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 5))); + vector3 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 6)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 6))); + vector4 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 7)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 7))); vresult = TAggregationOperator.Invoke(vresult, vector1); vresult = TAggregationOperator.Invoke(vresult, vector2); @@ -2014,10 +2023,10 @@ static float Vectorized256(ref float xRef, ref float yRef, nuint remainder) // We adjust the source and destination references, then update // the count of remaining elements to process. - xPtr += (uint)(Vector256.Count * 8); - yPtr += (uint)(Vector256.Count * 8); + xPtr += (uint)(Vector256.Count * 8); + yPtr += (uint)(Vector256.Count * 8); - remainder -= (uint)(Vector256.Count * 8); + remainder -= (uint)(Vector256.Count * 8); } // Adjusting the refs here allows us to avoid pinning for very small inputs @@ -2030,7 +2039,7 @@ static float Vectorized256(ref float xRef, ref float yRef, nuint remainder) // Store the first block. Handling this separately simplifies the latter code as we know // they come after and so we can relegate it to full blocks or the trailing elements - beg = Vector256.ConditionalSelect(CreateAlignmentMaskSingleVector256((int)(misalignment)), beg, Vector256.Create(TAggregationOperator.IdentityValue)); + beg = Vector256.ConditionalSelect(CreateAlignmentMaskVector256((int)misalignment), beg, Vector256.Create(TAggregationOperator.IdentityValue)); vresult = TAggregationOperator.Invoke(vresult, beg); // Process the remaining [0, Count * 7] elements via a jump table @@ -2039,7 +2048,7 @@ static float Vectorized256(ref float xRef, ref float yRef, nuint remainder) // worst case end up just doing the identity operation here if there // were no trailing elements. - (nuint blocks, nuint trailing) = Math.DivRem(remainder, (nuint)(Vector256.Count)); + (nuint blocks, nuint trailing) = Math.DivRem(remainder, (nuint)Vector256.Count); blocks -= (misalignment == 0) ? 
1u : 0u; remainder -= trailing; @@ -2047,56 +2056,56 @@ static float Vectorized256(ref float xRef, ref float yRef, nuint remainder) { case 7: { - Vector256 vector = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 7)), - Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 7))); + Vector256 vector = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 7)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 7))); vresult = TAggregationOperator.Invoke(vresult, vector); goto case 6; } case 6: { - Vector256 vector = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 6)), - Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 6))); + Vector256 vector = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 6)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 6))); vresult = TAggregationOperator.Invoke(vresult, vector); goto case 5; } case 5: { - Vector256 vector = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 5)), - Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 5))); + Vector256 vector = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 5)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 5))); vresult = TAggregationOperator.Invoke(vresult, vector); goto case 4; } case 4: { - Vector256 vector = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 4)), - Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 4))); + Vector256 vector = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 4)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 4))); vresult = TAggregationOperator.Invoke(vresult, vector); goto case 3; } case 3: { - Vector256 vector = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 3)), - Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 3))); + Vector256 vector = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 3)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 3))); vresult = TAggregationOperator.Invoke(vresult, vector); goto case 2; } case 2: { - Vector256 vector = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 2)), - Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 2))); + Vector256 vector = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 2)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 2))); vresult = TAggregationOperator.Invoke(vresult, vector); goto case 1; } case 1: { - Vector256 vector = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 1)), - Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 1))); + Vector256 vector = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 1)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 1))); vresult = TAggregationOperator.Invoke(vresult, vector); goto case 0; } @@ -2104,7 +2113,7 @@ static float Vectorized256(ref float xRef, ref float yRef, nuint remainder) case 0: { // Store the last block, which includes any 
elements that wouldn't fill a full vector - end = Vector256.ConditionalSelect(CreateRemainderMaskSingleVector256((int)(trailing)), end, Vector256.Create(TAggregationOperator.IdentityValue)); + end = Vector256.ConditionalSelect(CreateRemainderMaskVector256((int)trailing), end, Vector256.Create(TAggregationOperator.IdentityValue)); vresult = TAggregationOperator.Invoke(vresult, end); break; } @@ -2114,9 +2123,9 @@ static float Vectorized256(ref float xRef, ref float yRef, nuint remainder) } [MethodImpl(MethodImplOptions.AggressiveInlining)] - static float Vectorized256Small(ref float xRef, ref float yRef, nuint remainder) + static T Vectorized256Small(ref T xRef, ref T yRef, nuint remainder) { - float result = TAggregationOperator.IdentityValue; + T result = TAggregationOperator.IdentityValue; switch (remainder) { @@ -2125,14 +2134,14 @@ static float Vectorized256Small(ref float xRef, ref float yRef, nuint remainder) case 5: { Debug.Assert(Vector128.IsHardwareAccelerated); - Vector128 vresult = Vector128.Create(TAggregationOperator.IdentityValue); + Vector128 vresult = Vector128.Create(TAggregationOperator.IdentityValue); - Vector128 beg = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef), - Vector128.LoadUnsafe(ref yRef)); - Vector128 end = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count)), - Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count))); + Vector128 beg = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef), + Vector128.LoadUnsafe(ref yRef)); + Vector128 end = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)Vector128.Count), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)Vector128.Count)); - end = Vector128.ConditionalSelect(CreateRemainderMaskSingleVector128((int)(remainder % (uint)(Vector128.Count))), end, Vector128.Create(TAggregationOperator.IdentityValue)); + end = Vector128.ConditionalSelect(CreateRemainderMaskVector128((int)(remainder % (uint)Vector128.Count)), end, Vector128.Create(TAggregationOperator.IdentityValue)); vresult = TAggregationOperator.Invoke(vresult, beg); vresult = TAggregationOperator.Invoke(vresult, end); @@ -2144,10 +2153,10 @@ static float Vectorized256Small(ref float xRef, ref float yRef, nuint remainder) case 4: { Debug.Assert(Vector128.IsHardwareAccelerated); - Vector128 vresult = Vector128.Create(TAggregationOperator.IdentityValue); + Vector128 vresult = Vector128.Create(TAggregationOperator.IdentityValue); - Vector128 beg = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef), - Vector128.LoadUnsafe(ref yRef)); + Vector128 beg = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef), + Vector128.LoadUnsafe(ref yRef)); vresult = TAggregationOperator.Invoke(vresult, beg); result = TAggregationOperator.Invoke(vresult); @@ -2183,35 +2192,35 @@ static float Vectorized256Small(ref float xRef, ref float yRef, nuint remainder) return result; } - static float Vectorized512(ref float xRef, ref float yRef, nuint remainder) + static T Vectorized512(ref T xRef, ref T yRef, nuint remainder) { - Vector512 vresult = Vector512.Create(TAggregationOperator.IdentityValue); + Vector512 vresult = Vector512.Create(TAggregationOperator.IdentityValue); // Preload the beginning and end so that overlapping accesses don't negatively impact the data - Vector512 beg = TBinaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef), - Vector512.LoadUnsafe(ref yRef)); - Vector512 end = TBinaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count)), - Vector512.LoadUnsafe(ref 
yRef, remainder - (uint)(Vector512.Count))); + Vector512 beg = TBinaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef), + Vector512.LoadUnsafe(ref yRef)); + Vector512 end = TBinaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)Vector512.Count), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)Vector512.Count)); nuint misalignment = 0; - if (remainder > (uint)(Vector512.Count * 8)) + if (remainder > (uint)(Vector512.Count * 8)) { // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted. - fixed (float* px = &xRef) - fixed (float* py = &yRef) + fixed (T* px = &xRef) + fixed (T* py = &yRef) { - float* xPtr = px; - float* yPtr = py; + T* xPtr = px; + T* yPtr = py; // We need to ensure the underlying data can be aligned and only align // it if it can. It is possible we have an unaligned ref, in which case we // can never achieve the required SIMD alignment. - bool canAlign = ((nuint)(xPtr) % sizeof(float)) == 0; + bool canAlign = ((nuint)xPtr % (nuint)sizeof(T)) == 0; if (canAlign) { @@ -2221,35 +2230,35 @@ static float Vectorized512(ref float xRef, ref float yRef, nuint remainder) // are more expensive than unaligned loads and aligning both is significantly more // complex. - misalignment = ((uint)(sizeof(Vector512)) - ((nuint)(xPtr) % (uint)(sizeof(Vector512)))) / sizeof(float); + misalignment = ((uint)sizeof(Vector512) - ((nuint)xPtr % (uint)sizeof(Vector512))) / (uint)sizeof(T); xPtr += misalignment; yPtr += misalignment; - Debug.Assert(((nuint)(xPtr) % (uint)(sizeof(Vector512))) == 0); + Debug.Assert(((nuint)xPtr % (uint)sizeof(Vector512)) == 0); remainder -= misalignment; } - Vector512 vector1; - Vector512 vector2; - Vector512 vector3; - Vector512 vector4; + Vector512 vector1; + Vector512 vector2; + Vector512 vector3; + Vector512 vector4; // We only need to load, so there isn't a lot of benefit to doing non-temporal operations - while (remainder >= (uint)(Vector512.Count * 8)) + while (remainder >= (uint)(Vector512.Count * 8)) { // We load, process, and store the first four vectors - vector1 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 0)), - Vector512.Load(yPtr + (uint)(Vector512.Count * 0))); - vector2 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 1)), - Vector512.Load(yPtr + (uint)(Vector512.Count * 1))); - vector3 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 2)), - Vector512.Load(yPtr + (uint)(Vector512.Count * 2))); - vector4 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 3)), - Vector512.Load(yPtr + (uint)(Vector512.Count * 3))); + vector1 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 0)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 0))); + vector2 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 1)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 1))); + vector3 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 2)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 2))); + vector4 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 3)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 3))); vresult = TAggregationOperator.Invoke(vresult, vector1); vresult = TAggregationOperator.Invoke(vresult, vector2); @@ -2258,14 +2267,14 @@ static float Vectorized512(ref float xRef, ref float yRef, nuint remainder) // We load, process, and store the next four vectors
- vector1 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 4)), - Vector512.Load(yPtr + (uint)(Vector512.Count * 4))); - vector2 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 5)), - Vector512.Load(yPtr + (uint)(Vector512.Count * 5))); - vector3 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 6)), - Vector512.Load(yPtr + (uint)(Vector512.Count * 6))); - vector4 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 7)), - Vector512.Load(yPtr + (uint)(Vector512.Count * 7))); + vector1 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 4)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 4))); + vector2 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 5)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 5))); + vector3 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 6)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 6))); + vector4 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 7)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 7))); vresult = TAggregationOperator.Invoke(vresult, vector1); vresult = TAggregationOperator.Invoke(vresult, vector2); @@ -2275,10 +2284,10 @@ static float Vectorized512(ref float xRef, ref float yRef, nuint remainder) // We adjust the source and destination references, then update // the count of remaining elements to process. - xPtr += (uint)(Vector512.Count * 8); - yPtr += (uint)(Vector512.Count * 8); + xPtr += (uint)(Vector512.Count * 8); + yPtr += (uint)(Vector512.Count * 8); - remainder -= (uint)(Vector512.Count * 8); + remainder -= (uint)(Vector512.Count * 8); } // Adjusting the refs here allows us to avoid pinning for very small inputs @@ -2291,7 +2300,7 @@ static float Vectorized512(ref float xRef, ref float yRef, nuint remainder) // Store the first block. Handling this separately simplifies the latter code as we know // they come after and so we can relegate it to full blocks or the trailing elements - beg = Vector512.ConditionalSelect(CreateAlignmentMaskSingleVector512((int)(misalignment)), beg, Vector512.Create(TAggregationOperator.IdentityValue)); + beg = Vector512.ConditionalSelect(CreateAlignmentMaskVector512((int)misalignment), beg, Vector512.Create(TAggregationOperator.IdentityValue)); vresult = TAggregationOperator.Invoke(vresult, beg); // Process the remaining [0, Count * 7] elements via a jump table @@ -2300,7 +2309,7 @@ static float Vectorized512(ref float xRef, ref float yRef, nuint remainder) // worst case end up just doing the identity operation here if there // were no trailing elements. - (nuint blocks, nuint trailing) = Math.DivRem(remainder, (nuint)(Vector512.Count)); + (nuint blocks, nuint trailing) = Math.DivRem(remainder, (nuint)Vector512.Count); blocks -= (misalignment == 0) ? 
1u : 0u; remainder -= trailing; @@ -2308,56 +2317,56 @@ static float Vectorized512(ref float xRef, ref float yRef, nuint remainder) { case 7: { - Vector512 vector = TBinaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 7)), - Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 7))); + Vector512 vector = TBinaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 7)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 7))); vresult = TAggregationOperator.Invoke(vresult, vector); goto case 6; } case 6: { - Vector512 vector = TBinaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 6)), - Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 6))); + Vector512 vector = TBinaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 6)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 6))); vresult = TAggregationOperator.Invoke(vresult, vector); goto case 5; } case 5: { - Vector512 vector = TBinaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 5)), - Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 5))); + Vector512 vector = TBinaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 5)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 5))); vresult = TAggregationOperator.Invoke(vresult, vector); goto case 4; } case 4: { - Vector512 vector = TBinaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 4)), - Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 4))); + Vector512 vector = TBinaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 4)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 4))); vresult = TAggregationOperator.Invoke(vresult, vector); goto case 3; } case 3: { - Vector512 vector = TBinaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 3)), - Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 3))); + Vector512 vector = TBinaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 3)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 3))); vresult = TAggregationOperator.Invoke(vresult, vector); goto case 2; } case 2: { - Vector512 vector = TBinaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 2)), - Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 2))); + Vector512 vector = TBinaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 2)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 2))); vresult = TAggregationOperator.Invoke(vresult, vector); goto case 1; } case 1: { - Vector512 vector = TBinaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 1)), - Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 1))); + Vector512 vector = TBinaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 1)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 1))); vresult = TAggregationOperator.Invoke(vresult, vector); goto case 0; } @@ -2365,7 +2374,7 @@ static float Vectorized512(ref float xRef, ref float yRef, nuint remainder) case 0: { // Store the last block, which includes any 
elements that wouldn't fill a full vector - end = Vector512.ConditionalSelect(CreateRemainderMaskSingleVector512((int)(trailing)), end, Vector512.Create(TAggregationOperator.IdentityValue)); + end = Vector512.ConditionalSelect(CreateRemainderMaskVector512((int)trailing), end, Vector512.Create(TAggregationOperator.IdentityValue)); vresult = TAggregationOperator.Invoke(vresult, end); break; } @@ -2375,9 +2384,9 @@ static float Vectorized512(ref float xRef, ref float yRef, nuint remainder) } [MethodImpl(MethodImplOptions.AggressiveInlining)] - static float Vectorized512Small(ref float xRef, ref float yRef, nuint remainder) + static T Vectorized512Small(ref T xRef, ref T yRef, nuint remainder) { - float result = TAggregationOperator.IdentityValue; + T result = TAggregationOperator.IdentityValue; switch (remainder) { @@ -2390,14 +2399,14 @@ static float Vectorized512Small(ref float xRef, ref float yRef, nuint remainder) case 9: { Debug.Assert(Vector256.IsHardwareAccelerated); - Vector256 vresult = Vector256.Create(TAggregationOperator.IdentityValue); + Vector256 vresult = Vector256.Create(TAggregationOperator.IdentityValue); - Vector256 beg = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef), - Vector256.LoadUnsafe(ref yRef)); - Vector256 end = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count)), - Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count))); + Vector256 beg = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef), + Vector256.LoadUnsafe(ref yRef)); + Vector256 end = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)Vector256.Count), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)Vector256.Count)); - end = Vector256.ConditionalSelect(CreateRemainderMaskSingleVector256((int)(remainder % (uint)(Vector256.Count))), end, Vector256.Create(TAggregationOperator.IdentityValue)); + end = Vector256.ConditionalSelect(CreateRemainderMaskVector256((int)(remainder % (uint)Vector256.Count)), end, Vector256.Create(TAggregationOperator.IdentityValue)); vresult = TAggregationOperator.Invoke(vresult, beg); vresult = TAggregationOperator.Invoke(vresult, end); @@ -2409,10 +2418,10 @@ static float Vectorized512Small(ref float xRef, ref float yRef, nuint remainder) case 8: { Debug.Assert(Vector256.IsHardwareAccelerated); - Vector256 vresult = Vector256.Create(TAggregationOperator.IdentityValue); + Vector256 vresult = Vector256.Create(TAggregationOperator.IdentityValue); - Vector256 beg = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef), - Vector256.LoadUnsafe(ref yRef)); + Vector256 beg = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef), + Vector256.LoadUnsafe(ref yRef)); vresult = TAggregationOperator.Invoke(vresult, beg); result = TAggregationOperator.Invoke(vresult); @@ -2424,14 +2433,14 @@ static float Vectorized512Small(ref float xRef, ref float yRef, nuint remainder) case 5: { Debug.Assert(Vector128.IsHardwareAccelerated); - Vector128 vresult = Vector128.Create(TAggregationOperator.IdentityValue); + Vector128 vresult = Vector128.Create(TAggregationOperator.IdentityValue); - Vector128 beg = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef), - Vector128.LoadUnsafe(ref yRef)); - Vector128 end = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count)), - Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count))); + Vector128 beg = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef), + Vector128.LoadUnsafe(ref yRef)); + Vector128 end = 
TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)Vector128.Count), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)Vector128.Count)); - end = Vector128.ConditionalSelect(CreateRemainderMaskSingleVector128((int)(remainder % (uint)(Vector128.Count))), end, Vector128.Create(TAggregationOperator.IdentityValue)); + end = Vector128.ConditionalSelect(CreateRemainderMaskVector128((int)(remainder % (uint)Vector128.Count)), end, Vector128.Create(TAggregationOperator.IdentityValue)); vresult = TAggregationOperator.Invoke(vresult, beg); vresult = TAggregationOperator.Invoke(vresult, end); @@ -2443,10 +2452,10 @@ static float Vectorized512Small(ref float xRef, ref float yRef, nuint remainder) case 4: { Debug.Assert(Vector128.IsHardwareAccelerated); - Vector128 vresult = Vector128.Create(TAggregationOperator.IdentityValue); + Vector128 vresult = Vector128.Create(TAggregationOperator.IdentityValue); - Vector128 beg = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef), - Vector128.LoadUnsafe(ref yRef)); + Vector128 beg = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef), + Vector128.LoadUnsafe(ref yRef)); vresult = TAggregationOperator.Invoke(vresult, beg); result = TAggregationOperator.Invoke(vresult); @@ -2484,11 +2493,12 @@ static float Vectorized512Small(ref float xRef, ref float yRef, nuint remainder) } /// - /// This is the same as + /// This is the same as /// with an identity transform, except it early exits on NaN. /// - private static float MinMaxCore(ReadOnlySpan x) - where TMinMaxOperator : struct, IAggregationOperator + private static T MinMaxCore(ReadOnlySpan x) + where T : INumberBase + where TMinMaxOperator : struct, IAggregationOperator { if (x.IsEmpty) { @@ -2500,23 +2510,28 @@ private static float MinMaxCore(ReadOnlySpan x) // otherwise returns the greater of the inputs. // It treats +0 as greater than -0 as per the specification. - if (Vector512.IsHardwareAccelerated && x.Length >= Vector512.Count) + if (Vector512.IsHardwareAccelerated && Vector512.IsSupported && x.Length >= Vector512.Count) { - ref float xRef = ref MemoryMarshal.GetReference(x); + ref T xRef = ref MemoryMarshal.GetReference(x); // Load the first vector as the initial set of results, and bail immediately // to scalar handling if it contains any NaNs (which don't compare equally to themselves). - Vector512 result = Vector512.LoadUnsafe(ref xRef, 0); - Vector512 current; + Vector512 result = Vector512.LoadUnsafe(ref xRef, 0); + Vector512 current; - Vector512 nanMask = ~Vector512.Equals(result, result); - if (nanMask != Vector512.Zero) + Vector512 nanMask; + if (typeof(T) == typeof(float) || typeof(T) == typeof(double)) { - return result.GetElement(IndexOfFirstMatch(nanMask)); + // Check for NaNs + nanMask = ~Vector512.Equals(result, result); + if (nanMask != Vector512.Zero) + { + return result.GetElement(IndexOfFirstMatch(nanMask)); + } } - int oneVectorFromEnd = x.Length - Vector512.Count; - int i = Vector512.Count; + int oneVectorFromEnd = x.Length - Vector512.Count; + int i = Vector512.Count; // Aggregate additional vectors into the result as long as there's at least one full vector left to process. while (i <= oneVectorFromEnd) @@ -2524,25 +2539,33 @@ private static float MinMaxCore(ReadOnlySpan x) // Load the next vector, and early exit on NaN. 
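The NaN early exit above leans on IEEE 754: NaN is the only value that does not compare equal to itself, so Vector512.Equals(result, result) produces all-bits-set lanes everywhere except NaN lanes, and the complement is a NaN mask. The IndexOfFirstMatch helpers later in this diff turn that mask into a lane index with a bit scan. A minimal, self-contained sketch of the same trick at Vector128 width; IndexOfFirstNaN is an illustrative name, not a library API:

```csharp
using System;
using System.Numerics;
using System.Runtime.Intrinsics;

class NanScanDemo
{
    // Returns the index of the first NaN lane, or -1 if there is none.
    static int IndexOfFirstNaN(Vector128<float> v)
    {
        // NaN != NaN, so Equals(v, v) sets all bits in every non-NaN lane;
        // the complement therefore marks exactly the NaN lanes.
        Vector128<float> nanMask = ~Vector128.Equals(v, v);
        if (nanMask == Vector128<float>.Zero)
        {
            return -1;
        }

        // One sign bit per lane; the lowest set bit is the first NaN lane.
        return BitOperations.TrailingZeroCount(nanMask.ExtractMostSignificantBits());
    }

    static void Main()
    {
        Vector128<float> v = Vector128.Create(1f, float.NaN, 3f, 4f);
        Console.WriteLine(IndexOfFirstNaN(v)); // prints 1
    }
}
```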
current = Vector512.LoadUnsafe(ref xRef, (uint)i); - nanMask = ~Vector512.Equals(current, current); - if (nanMask != Vector512.Zero) + if (typeof(T) == typeof(float) || typeof(T) == typeof(double)) { - return current.GetElement(IndexOfFirstMatch(nanMask)); + // Check for NaNs + nanMask = ~Vector512.Equals(current, current); + if (nanMask != Vector512.Zero) + { + return current.GetElement(IndexOfFirstMatch(nanMask)); + } } result = TMinMaxOperator.Invoke(result, current); - i += Vector512.Count; + i += Vector512.Count; } // If any elements remain, handle them in one final vector. if (i != x.Length) { - current = Vector512.LoadUnsafe(ref xRef, (uint)(x.Length - Vector512.Count)); + current = Vector512.LoadUnsafe(ref xRef, (uint)(x.Length - Vector512.Count)); - nanMask = ~Vector512.Equals(current, current); - if (nanMask != Vector512.Zero) + if (typeof(T) == typeof(float) || typeof(T) == typeof(double)) { - return current.GetElement(IndexOfFirstMatch(nanMask)); + // Check for NaNs + nanMask = ~Vector512.Equals(current, current); + if (nanMask != Vector512.Zero) + { + return current.GetElement(IndexOfFirstMatch(nanMask)); + } } result = TMinMaxOperator.Invoke(result, current); @@ -2552,23 +2575,28 @@ private static float MinMaxCore(ReadOnlySpan x) return TMinMaxOperator.Invoke(result); } - if (Vector256.IsHardwareAccelerated && x.Length >= Vector256.Count) + if (Vector256.IsHardwareAccelerated && Vector256.IsSupported && x.Length >= Vector256.Count) { - ref float xRef = ref MemoryMarshal.GetReference(x); + ref T xRef = ref MemoryMarshal.GetReference(x); // Load the first vector as the initial set of results, and bail immediately // to scalar handling if it contains any NaNs (which don't compare equally to themselves). - Vector256 result = Vector256.LoadUnsafe(ref xRef, 0); - Vector256 current; + Vector256 result = Vector256.LoadUnsafe(ref xRef, 0); + Vector256 current; - Vector256 nanMask = ~Vector256.Equals(result, result); - if (nanMask != Vector256.Zero) + Vector256 nanMask; + if (typeof(T) == typeof(float) || typeof(T) == typeof(double)) { - return result.GetElement(IndexOfFirstMatch(nanMask)); + // Check for NaNs + nanMask = ~Vector256.Equals(result, result); + if (nanMask != Vector256.Zero) + { + return result.GetElement(IndexOfFirstMatch(nanMask)); + } } - int oneVectorFromEnd = x.Length - Vector256.Count; - int i = Vector256.Count; + int oneVectorFromEnd = x.Length - Vector256.Count; + int i = Vector256.Count; // Aggregate additional vectors into the result as long as there's at least one full vector left to process. while (i <= oneVectorFromEnd) @@ -2576,25 +2604,34 @@ private static float MinMaxCore(ReadOnlySpan x) // Load the next vector, and early exit on NaN. current = Vector256.LoadUnsafe(ref xRef, (uint)i); - nanMask = ~Vector256.Equals(current, current); - if (nanMask != Vector256.Zero) + if (typeof(T) == typeof(float) || typeof(T) == typeof(double)) { - return current.GetElement(IndexOfFirstMatch(nanMask)); + // Check for NaNs + nanMask = ~Vector256.Equals(current, current); + if (nanMask != Vector256.Zero) + { + return current.GetElement(IndexOfFirstMatch(nanMask)); + } } result = TMinMaxOperator.Invoke(result, current); - i += Vector256.Count; + i += Vector256.Count; } // If any elements remain, handle them in one final vector. 
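Worth noting about the "one final vector" step that follows: instead of a scalar tail loop, the code reloads a full vector that ends exactly at the last element. The lanes that overlap the previous iteration get aggregated twice, which is harmless for an idempotent reduction like Min or Max (max(a, a) == a) but would double-count for Sum, so the trick is specific to this family of operators. A hedged sketch of the idea, assuming a NaN-free input at least one vector long:

```csharp
using System;
using System.Runtime.InteropServices;
using System.Runtime.Intrinsics;

class OverlappingTailDemo
{
    // Illustrative vectorized Max over a span, assuming x.Length is at least
    // Vector128<float>.Count and the input is NaN-free (the real code handles
    // both; this sketch only demonstrates the overlapping-tail idea).
    static float MaxOf(ReadOnlySpan<float> x)
    {
        ref float xRef = ref MemoryMarshal.GetReference(x);

        Vector128<float> result = Vector128.LoadUnsafe(ref xRef, 0);
        int i = Vector128<float>.Count;

        // Aggregate all remaining full vectors.
        for (; i <= x.Length - Vector128<float>.Count; i += Vector128<float>.Count)
        {
            result = Vector128.Max(result, Vector128.LoadUnsafe(ref xRef, (uint)i));
        }

        // Tail: reload a full vector that ends at the last element. Lanes that
        // overlap the previous load are aggregated twice, which is harmless
        // because max(a, a) == a; Sum could not reuse this unmodified.
        if (i != x.Length)
        {
            result = Vector128.Max(result,
                Vector128.LoadUnsafe(ref xRef, (uint)(x.Length - Vector128<float>.Count)));
        }

        // Horizontal reduction of the accumulator lanes.
        float max = result.GetElement(0);
        for (int lane = 1; lane < Vector128<float>.Count; lane++)
        {
            max = Math.Max(max, result.GetElement(lane));
        }

        return max;
    }

    static void Main()
    {
        Console.WriteLine(MaxOf(new float[] { 3f, 9f, 1f, 7f, 5f, 8f, 2f })); // 9
    }
}
```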
if (i != x.Length) { - current = Vector256.LoadUnsafe(ref xRef, (uint)(x.Length - Vector256.Count)); + current = Vector256.LoadUnsafe(ref xRef, (uint)(x.Length - Vector256.Count)); - nanMask = ~Vector256.Equals(current, current); - if (nanMask != Vector256.Zero) + + if (typeof(T) == typeof(float) || typeof(T) == typeof(double)) { - return current.GetElement(IndexOfFirstMatch(nanMask)); + // Check for NaNs + nanMask = ~Vector256.Equals(current, current); + if (nanMask != Vector256.Zero) + { + return current.GetElement(IndexOfFirstMatch(nanMask)); + } } result = TMinMaxOperator.Invoke(result, current); @@ -2604,23 +2641,28 @@ private static float MinMaxCore(ReadOnlySpan x) return TMinMaxOperator.Invoke(result); } - if (Vector128.IsHardwareAccelerated && x.Length >= Vector128.Count) + if (Vector128.IsHardwareAccelerated && Vector128.IsSupported && x.Length >= Vector128.Count) { - ref float xRef = ref MemoryMarshal.GetReference(x); + ref T xRef = ref MemoryMarshal.GetReference(x); // Load the first vector as the initial set of results, and bail immediately // to scalar handling if it contains any NaNs (which don't compare equally to themselves). - Vector128 result = Vector128.LoadUnsafe(ref xRef, 0); - Vector128 current; + Vector128 result = Vector128.LoadUnsafe(ref xRef, 0); + Vector128 current; - Vector128 nanMask = ~Vector128.Equals(result, result); - if (nanMask != Vector128.Zero) + Vector128 nanMask; + if (typeof(T) == typeof(float) || typeof(T) == typeof(double)) { - return result.GetElement(IndexOfFirstMatch(nanMask)); + // Check for NaNs + nanMask = ~Vector128.Equals(result, result); + if (nanMask != Vector128.Zero) + { + return result.GetElement(IndexOfFirstMatch(nanMask)); + } } - int oneVectorFromEnd = x.Length - Vector128.Count; - int i = Vector128.Count; + int oneVectorFromEnd = x.Length - Vector128.Count; + int i = Vector128.Count; // Aggregate additional vectors into the result as long as there's at least one full vector left to process. while (i <= oneVectorFromEnd) @@ -2628,25 +2670,33 @@ private static float MinMaxCore(ReadOnlySpan x) // Load the next vector, and early exit on NaN. current = Vector128.LoadUnsafe(ref xRef, (uint)i); - nanMask = ~Vector128.Equals(current, current); - if (nanMask != Vector128.Zero) + if (typeof(T) == typeof(float) || typeof(T) == typeof(double)) { - return current.GetElement(IndexOfFirstMatch(nanMask)); + // Check for NaNs + nanMask = ~Vector128.Equals(current, current); + if (nanMask != Vector128.Zero) + { + return current.GetElement(IndexOfFirstMatch(nanMask)); + } } result = TMinMaxOperator.Invoke(result, current); - i += Vector128.Count; + i += Vector128.Count; } // If any elements remain, handle them in one final vector. if (i != x.Length) { - current = Vector128.LoadUnsafe(ref xRef, (uint)(x.Length - Vector128.Count)); + current = Vector128.LoadUnsafe(ref xRef, (uint)(x.Length - Vector128.Count)); - nanMask = ~Vector128.Equals(current, current); - if (nanMask != Vector128.Zero) + if (typeof(T) == typeof(float) || typeof(T) == typeof(double)) { - return current.GetElement(IndexOfFirstMatch(nanMask)); + // Check for NaNs + nanMask = ~Vector128.Equals(current, current); + if (nanMask != Vector128.Zero) + { + return current.GetElement(IndexOfFirstMatch(nanMask)); + } } result = TMinMaxOperator.Invoke(result, current); @@ -2657,16 +2707,16 @@ private static float MinMaxCore(ReadOnlySpan x) } // Scalar path used when either vectorization is not supported or the input is too small to vectorize. 
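One reason the new typeof(T) == typeof(float) || typeof(T) == typeof(double) guards (and the Vector512<T>.IsSupported checks) are acceptable in a hot path: for a value-type T the JIT compiles a separate body per instantiation and folds such comparisons to constants, so integer instantiations drop the entire NaN block as dead code. A small sketch of the pattern; IsNaNAware is an illustrative helper, not part of the library:

```csharp
using System;
using System.Numerics;
using System.Runtime.CompilerServices;

static class SpecializationDemo
{
    // For a value-type T, typeof(T) == typeof(float) is a JIT-time constant:
    // the float/double instantiations keep the check, while e.g. the int
    // instantiation compiles down to "return false" with no branch at all.
    // The vectorized paths use the same pattern to skip the NaN scan entirely
    // for integer element types.
    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    public static bool IsNaNAware<T>(T value) where T : INumberBase<T>
    {
        if (typeof(T) == typeof(float) || typeof(T) == typeof(double))
        {
            return T.IsNaN(value);
        }

        return false;
    }

    static void Main()
    {
        Console.WriteLine(IsNaNAware(float.NaN)); // True
        Console.WriteLine(IsNaNAware(42));        // False; branch folded away
    }
}
```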
- float curResult = x[0]; - if (float.IsNaN(curResult)) + T curResult = x[0]; + if (T.IsNaN(curResult)) { return curResult; } for (int i = 1; i < x.Length; i++) { - float current = x[i]; - if (float.IsNaN(current)) + T current = x[i]; + if (T.IsNaN(current)) { return current; } @@ -2677,7 +2727,8 @@ private static float MinMaxCore(ReadOnlySpan x) return curResult; } - private static int IndexOfMinMaxCore(ReadOnlySpan x) where TIndexOfMinMax : struct, IIndexOfOperator + private static int IndexOfMinMaxCore(ReadOnlySpan x) + where TIndexOfMinMax : struct, IIndexOfOperator { if (x.IsEmpty) { @@ -2886,26 +2937,27 @@ private static int IndexOfMinMaxCore(ReadOnlySpan x) wher return curIn; } - private static int IndexOfFirstMatch(Vector128 mask) + private static int IndexOfFirstMatch(Vector128 mask) { return BitOperations.TrailingZeroCount(mask.ExtractMostSignificantBits()); } - private static int IndexOfFirstMatch(Vector256 mask) + private static int IndexOfFirstMatch(Vector256 mask) { return BitOperations.TrailingZeroCount(mask.ExtractMostSignificantBits()); } - private static int IndexOfFirstMatch(Vector512 mask) + private static int IndexOfFirstMatch(Vector512 mask) { return BitOperations.TrailingZeroCount(mask.ExtractMostSignificantBits()); } /// Performs an element-wise operation on and writes the results to . + /// The element type. /// Specifies the operation to perform on each element loaded from . - private static void InvokeSpanIntoSpan( - ReadOnlySpan x, Span destination) - where TUnaryOperator : struct, IUnaryOperator + private static void InvokeSpanIntoSpan( + ReadOnlySpan x, Span destination) + where TUnaryOperator : struct, IUnaryOperator { if (x.Length > destination.Length) { @@ -2919,14 +2971,14 @@ private static void InvokeSpanIntoSpan( // in a way that allows us to have the minimum possible // for small sizes - ref float xRef = ref MemoryMarshal.GetReference(x); - ref float dRef = ref MemoryMarshal.GetReference(destination); + ref T xRef = ref MemoryMarshal.GetReference(x); + ref T dRef = ref MemoryMarshal.GetReference(destination); - nuint remainder = (uint)(x.Length); + nuint remainder = (uint)x.Length; - if (Vector512.IsHardwareAccelerated) + if (Vector512.IsHardwareAccelerated && Vector512.IsSupported && TUnaryOperator.Vectorizable && Unsafe.SizeOf() >= 4) { - if (remainder >= (uint)(Vector512.Count)) + if (remainder >= (uint)Vector512.Count) { Vectorized512(ref xRef, ref dRef, remainder); } @@ -2942,9 +2994,9 @@ private static void InvokeSpanIntoSpan( return; } - if (Vector256.IsHardwareAccelerated) + if (Vector256.IsHardwareAccelerated && Vector256.IsSupported && TUnaryOperator.Vectorizable && Unsafe.SizeOf() >= 4) { - if (remainder >= (uint)(Vector256.Count)) + if (remainder >= (uint)Vector256.Count) { Vectorized256(ref xRef, ref dRef, remainder); } @@ -2960,9 +3012,9 @@ private static void InvokeSpanIntoSpan( return; } - if (Vector128.IsHardwareAccelerated) + if (Vector128.IsHardwareAccelerated && Vector128.IsSupported && TUnaryOperator.Vectorizable && Unsafe.SizeOf() >= 4) { - if (remainder >= (uint)(Vector128.Count)) + if (remainder >= (uint)Vector128.Count) { Vectorized128(ref xRef, ref dRef, remainder); } @@ -2984,7 +3036,7 @@ private static void InvokeSpanIntoSpan( SoftwareFallback(ref xRef, ref dRef, remainder); [MethodImpl(MethodImplOptions.AggressiveInlining)] - static void SoftwareFallback(ref float xRef, ref float dRef, nuint length) + static void SoftwareFallback(ref T xRef, ref T dRef, nuint length) { for (nuint i = 0; i < length; i++) { @@ -2992,31 
+3044,31 @@ static void SoftwareFallback(ref float xRef, ref float dRef, nuint length) } } - static void Vectorized128(ref float xRef, ref float dRef, nuint remainder) + static void Vectorized128(ref T xRef, ref T dRef, nuint remainder) { - ref float dRefBeg = ref dRef; + ref T dRefBeg = ref dRef; // Preload the beginning and end so that overlapping accesses don't negatively impact the data - Vector128 beg = TUnaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef)); - Vector128 end = TUnaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count))); + Vector128 beg = TUnaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef)); + Vector128 end = TUnaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)Vector128.Count)); - if (remainder > (uint)(Vector128.Count * 8)) + if (remainder > (uint)(Vector128.Count * 8)) { // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted. - fixed (float* px = &xRef) - fixed (float* pd = &dRef) + fixed (T* px = &xRef) + fixed (T* pd = &dRef) { - float* xPtr = px; - float* dPtr = pd; + T* xPtr = px; + T* dPtr = pd; // We need to ensure the underlying data can be aligned and only align // it if it can. It is possible we have an unaligned ref, in which case we // can never achieve the required SIMD alignment. - bool canAlign = ((nuint)(dPtr) % sizeof(float)) == 0; + bool canAlign = ((nuint)dPtr % (nuint)sizeof(T)) == 0; if (canAlign) { @@ -3026,96 +3078,96 @@ static void Vectorized128(ref float xRef, ref float dRef, nuint remainder) // are more expensive than unaligned loads and aligning both is significantly more // complex. - nuint misalignment = ((uint)(sizeof(Vector128)) - ((nuint)(dPtr) % (uint)(sizeof(Vector128)))) / sizeof(float); + nuint misalignment = ((uint)sizeof(Vector128) - ((nuint)dPtr % (uint)sizeof(Vector128))) / (uint)sizeof(T); xPtr += misalignment; dPtr += misalignment; - Debug.Assert(((nuint)(dPtr) % (uint)(sizeof(Vector128))) == 0); + Debug.Assert(((nuint)dPtr % (uint)sizeof(Vector128)) == 0); remainder -= misalignment; } - Vector128 vector1; - Vector128 vector2; - Vector128 vector3; - Vector128 vector4; + Vector128 vector1; + Vector128 vector2; + Vector128 vector3; + Vector128 vector4; - if ((remainder > (NonTemporalByteThreshold / sizeof(float))) && canAlign) + if ((remainder > (NonTemporalByteThreshold / (nuint)sizeof(T))) && canAlign) { // This loop stores the data non-temporally, which benefits us when there // is a large amount of data involved as it avoids polluting the cache.
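Since NonTemporalByteThreshold is a byte count, the generic code divides it by sizeof(T) instead of sizeof(float) to compare against the element-count remainder. The sketch below shows the shape of that decision for a simple unary transform; the 256 KiB threshold and DoubleAll are stand-in names for illustration, and the cache-bypassing store additionally requires an aligned destination:

```csharp
using System;
using System.Runtime.Intrinsics;

// Compile with unsafe blocks enabled.
static unsafe class NonTemporalDemo
{
    // Stand-in threshold; the library's actual constant is internal.
    private const nuint NonTemporalByteThreshold = 256 * 1024;

    // Doubles each element, choosing the store flavor once, up front:
    // cache-bypassing stores only pay off when the output is large enough
    // to otherwise evict useful cache lines, and they require an aligned
    // destination.
    public static void DoubleAll(float* src, float* dst, nuint count)
    {
        bool aligned = ((nuint)dst % (uint)sizeof(Vector128<float>)) == 0;
        bool nonTemporal = aligned && count > NonTemporalByteThreshold / (nuint)sizeof(float);

        nuint i = 0;
        for (; i + (uint)Vector128<float>.Count <= count; i += (uint)Vector128<float>.Count)
        {
            Vector128<float> v = Vector128.Load(src + i) * 2f;

            if (nonTemporal)
            {
                v.StoreAlignedNonTemporal(dst + i); // bypasses the cache
            }
            else
            {
                v.Store(dst + i);                   // ordinary, cacheable store
            }
        }

        for (; i < count; i++) // scalar tail
        {
            dst[i] = src[i] * 2f;
        }
    }

    static void Main()
    {
        float* src = stackalloc float[8] { 1, 2, 3, 4, 5, 6, 7, 8 };
        float* dst = stackalloc float[8];

        DoubleAll(src, dst, 8);
        Console.WriteLine(dst[7]); // 16
    }
}
```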
- while (remainder >= (uint)(Vector128.Count * 8)) + while (remainder >= (uint)(Vector128.Count * 8)) { // We load, process, and store the first four vectors - vector1 = TUnaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 0))); - vector2 = TUnaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 1))); - vector3 = TUnaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 2))); - vector4 = TUnaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 3))); + vector1 = TUnaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 0))); + vector2 = TUnaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 1))); + vector3 = TUnaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 2))); + vector4 = TUnaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 3))); - vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 0)); - vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 1)); - vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 2)); - vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 3)); + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 0)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 1)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 2)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 3)); // We load, process, and store the next four vectors - vector1 = TUnaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 4))); - vector2 = TUnaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 5))); - vector3 = TUnaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 6))); - vector4 = TUnaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 7))); + vector1 = TUnaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 4))); + vector2 = TUnaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 5))); + vector3 = TUnaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 6))); + vector4 = TUnaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 7))); - vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 4)); - vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 5)); - vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 6)); - vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 7)); + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 4)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 5)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 6)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 7)); // We adjust the source and destination references, then update // the count of remaining elements to process. 
- xPtr += (uint)(Vector128.Count * 8); - dPtr += (uint)(Vector128.Count * 8); + xPtr += (uint)(Vector128.Count * 8); + dPtr += (uint)(Vector128.Count * 8); - remainder -= (uint)(Vector128.Count * 8); + remainder -= (uint)(Vector128.Count * 8); } } else { - while (remainder >= (uint)(Vector128.Count * 8)) + while (remainder >= (uint)(Vector128.Count * 8)) { // We load, process, and store the first four vectors - vector1 = TUnaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 0))); - vector2 = TUnaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 1))); - vector3 = TUnaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 2))); - vector4 = TUnaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 3))); + vector1 = TUnaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 0))); + vector2 = TUnaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 1))); + vector3 = TUnaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 2))); + vector4 = TUnaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 3))); - vector1.Store(dPtr + (uint)(Vector128.Count * 0)); - vector2.Store(dPtr + (uint)(Vector128.Count * 1)); - vector3.Store(dPtr + (uint)(Vector128.Count * 2)); - vector4.Store(dPtr + (uint)(Vector128.Count * 3)); + vector1.Store(dPtr + (uint)(Vector128.Count * 0)); + vector2.Store(dPtr + (uint)(Vector128.Count * 1)); + vector3.Store(dPtr + (uint)(Vector128.Count * 2)); + vector4.Store(dPtr + (uint)(Vector128.Count * 3)); // We load, process, and store the next four vectors - vector1 = TUnaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 4))); - vector2 = TUnaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 5))); - vector3 = TUnaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 6))); - vector4 = TUnaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 7))); + vector1 = TUnaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 4))); + vector2 = TUnaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 5))); + vector3 = TUnaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 6))); + vector4 = TUnaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 7))); - vector1.Store(dPtr + (uint)(Vector128.Count * 4)); - vector2.Store(dPtr + (uint)(Vector128.Count * 5)); - vector3.Store(dPtr + (uint)(Vector128.Count * 6)); - vector4.Store(dPtr + (uint)(Vector128.Count * 7)); + vector1.Store(dPtr + (uint)(Vector128.Count * 4)); + vector2.Store(dPtr + (uint)(Vector128.Count * 5)); + vector3.Store(dPtr + (uint)(Vector128.Count * 6)); + vector4.Store(dPtr + (uint)(Vector128.Count * 7)); // We adjust the source and destination references, then update // the count of remaining elements to process. - xPtr += (uint)(Vector128.Count * 8); - dPtr += (uint)(Vector128.Count * 8); + xPtr += (uint)(Vector128.Count * 8); + dPtr += (uint)(Vector128.Count * 8); - remainder -= (uint)(Vector128.Count * 8); + remainder -= (uint)(Vector128.Count * 8); } } @@ -3134,63 +3186,63 @@ static void Vectorized128(ref float xRef, ref float dRef, nuint remainder) // data before the first aligned address. 
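The remainder dispatch that follows is the clever bit: round the leftover element count up to a whole number of vectors, then fall through a switch on remainder / Count so exactly that many full-vector stores execute, with the preloaded end vector covering the final, possibly partial block. A small arithmetic sketch of the rounding (inputs chosen purely for illustration):

```csharp
using System;

class RemainderDispatchDemo
{
    // Mirrors the rounding used by the vectorized paths: round `remainder`
    // up to a whole multiple of the vector width, then dispatch on the number
    // of vectors via a fall-through switch.
    static void Show(nuint remainder, int vectorCount)
    {
        nuint endIndex = remainder;

        // (x + (n - 1)) & -n rounds x up to the next multiple of n
        // (n must be a power of two, which vector counts always are).
        remainder = (remainder + (uint)(vectorCount - 1)) & (nuint)(-vectorCount);

        Console.WriteLine($"{endIndex} elements -> {remainder / (uint)vectorCount} vector stores, " +
                          $"last store starts at element {endIndex - (uint)vectorCount}");
    }

    static void Main()
    {
        Show((nuint)19, 8); // 19 elements -> 3 vector stores, last store starts at element 11
        Show((nuint)16, 8); // 16 elements -> 2 vector stores, last store starts at element 8
    }
}
```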
nuint endIndex = remainder; - remainder = (remainder + (uint)(Vector128.Count - 1)) & (nuint)(-Vector128.Count); + remainder = (remainder + (uint)(Vector128.Count - 1)) & (nuint)(-Vector128.Count); - switch (remainder / (uint)(Vector128.Count)) + switch (remainder / (uint)Vector128.Count) { case 8: { - Vector128 vector = TUnaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 8))); - vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 8)); + Vector128 vector = TUnaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 8))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 8)); goto case 7; } case 7: { - Vector128 vector = TUnaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 7))); - vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 7)); + Vector128 vector = TUnaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 7))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 7)); goto case 6; } case 6: { - Vector128 vector = TUnaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 6))); - vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 6)); + Vector128 vector = TUnaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 6))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 6)); goto case 5; } case 5: { - Vector128 vector = TUnaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 5))); - vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 5)); + Vector128 vector = TUnaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 5))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 5)); goto case 4; } case 4: { - Vector128 vector = TUnaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 4))); - vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 4)); + Vector128 vector = TUnaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 4))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 4)); goto case 3; } case 3: { - Vector128 vector = TUnaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 3))); - vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 3)); + Vector128 vector = TUnaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 3))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 3)); goto case 2; } case 2: { - Vector128 vector = TUnaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 2))); - vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 2)); + Vector128 vector = TUnaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 2))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 2)); goto case 1; } case 1: { // Store the last block, which includes any elements that wouldn't fill a full vector - end.StoreUnsafe(ref dRef, endIndex - (uint)Vector128.Count); + end.StoreUnsafe(ref dRef, endIndex - (uint)Vector128.Count); goto case 0; } @@ -3204,7 +3256,7 @@ static void Vectorized128(ref float xRef, ref float dRef, nuint remainder) } [MethodImpl(MethodImplOptions.AggressiveInlining)] - static void Vectorized128Small(ref 
float xRef, ref float dRef, nuint remainder) + static void Vectorized128Small(ref T xRef, ref T dRef, nuint remainder) { switch (remainder) { @@ -3233,31 +3285,31 @@ static void Vectorized128Small(ref float xRef, ref float dRef, nuint remainder) } } - static void Vectorized256(ref float xRef, ref float dRef, nuint remainder) + static void Vectorized256(ref T xRef, ref T dRef, nuint remainder) { - ref float dRefBeg = ref dRef; + ref T dRefBeg = ref dRef; // Preload the beginning and end so that overlapping accesses don't negatively impact the data - Vector256 beg = TUnaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef)); - Vector256 end = TUnaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count))); + Vector256 beg = TUnaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef)); + Vector256 end = TUnaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)Vector256.Count)); - if (remainder > (uint)(Vector256.Count * 8)) + if (remainder > (uint)(Vector256.Count * 8)) { // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted. - fixed (float* px = &xRef) - fixed (float* pd = &dRef) + fixed (T* px = &xRef) + fixed (T* pd = &dRef) { - float* xPtr = px; - float* dPtr = pd; + T* xPtr = px; + T* dPtr = pd; // We need to ensure the underlying data can be aligned and only align // it if it can. It is possible we have an unaligned ref, in which case we // can never achieve the required SIMD alignment. - bool canAlign = ((nuint)(dPtr) % sizeof(float)) == 0; + bool canAlign = ((nuint)dPtr % (nuint)sizeof(T)) == 0; if (canAlign) { @@ -3267,96 +3319,96 @@ static void Vectorized256(ref float xRef, ref float dRef, nuint remainder) // are more expensive than unaligned loads and aligning both is significantly more // complex. - nuint misalignment = ((uint)(sizeof(Vector256)) - ((nuint)(dPtr) % (uint)(sizeof(Vector256)))) / sizeof(float); + nuint misalignment = ((uint)sizeof(Vector256) - ((nuint)dPtr % (uint)sizeof(Vector256))) / (uint)sizeof(T); xPtr += misalignment; dPtr += misalignment; - Debug.Assert(((nuint)(dPtr) % (uint)(sizeof(Vector256))) == 0); + Debug.Assert(((nuint)dPtr % (uint)sizeof(Vector256)) == 0); remainder -= misalignment; } - Vector256 vector1; - Vector256 vector2; - Vector256 vector3; - Vector256 vector4; + Vector256 vector1; + Vector256 vector2; + Vector256 vector3; + Vector256 vector4; - if ((remainder > (NonTemporalByteThreshold / sizeof(float))) && canAlign) + if ((remainder > (NonTemporalByteThreshold / (nuint)sizeof(T))) && canAlign) { // This loop stores the data non-temporally, which benefits us when there // is a large amount of data involved as it avoids polluting the cache.
- while (remainder >= (uint)(Vector256.Count * 8)) + while (remainder >= (uint)(Vector256.Count * 8)) { // We load, process, and store the first four vectors - vector1 = TUnaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 0))); - vector2 = TUnaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 1))); - vector3 = TUnaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 2))); - vector4 = TUnaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 3))); + vector1 = TUnaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 0))); + vector2 = TUnaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 1))); + vector3 = TUnaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 2))); + vector4 = TUnaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 3))); - vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 0)); - vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 1)); - vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 2)); - vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 3)); + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 0)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 1)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 2)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 3)); // We load, process, and store the next four vectors - vector1 = TUnaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 4))); - vector2 = TUnaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 5))); - vector3 = TUnaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 6))); - vector4 = TUnaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 7))); + vector1 = TUnaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 4))); + vector2 = TUnaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 5))); + vector3 = TUnaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 6))); + vector4 = TUnaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 7))); - vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 4)); - vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 5)); - vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 6)); - vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 7)); + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 4)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 5)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 6)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 7)); // We adjust the source and destination references, then update // the count of remaining elements to process. 
- xPtr += (uint)(Vector256.Count * 8); - dPtr += (uint)(Vector256.Count * 8); + xPtr += (uint)(Vector256.Count * 8); + dPtr += (uint)(Vector256.Count * 8); - remainder -= (uint)(Vector256.Count * 8); + remainder -= (uint)(Vector256.Count * 8); } } else { - while (remainder >= (uint)(Vector256.Count * 8)) + while (remainder >= (uint)(Vector256.Count * 8)) { // We load, process, and store the first four vectors - vector1 = TUnaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 0))); - vector2 = TUnaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 1))); - vector3 = TUnaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 2))); - vector4 = TUnaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 3))); + vector1 = TUnaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 0))); + vector2 = TUnaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 1))); + vector3 = TUnaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 2))); + vector4 = TUnaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 3))); - vector1.Store(dPtr + (uint)(Vector256.Count * 0)); - vector2.Store(dPtr + (uint)(Vector256.Count * 1)); - vector3.Store(dPtr + (uint)(Vector256.Count * 2)); - vector4.Store(dPtr + (uint)(Vector256.Count * 3)); + vector1.Store(dPtr + (uint)(Vector256.Count * 0)); + vector2.Store(dPtr + (uint)(Vector256.Count * 1)); + vector3.Store(dPtr + (uint)(Vector256.Count * 2)); + vector4.Store(dPtr + (uint)(Vector256.Count * 3)); // We load, process, and store the next four vectors - vector1 = TUnaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 4))); - vector2 = TUnaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 5))); - vector3 = TUnaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 6))); - vector4 = TUnaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 7))); + vector1 = TUnaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 4))); + vector2 = TUnaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 5))); + vector3 = TUnaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 6))); + vector4 = TUnaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 7))); - vector1.Store(dPtr + (uint)(Vector256.Count * 4)); - vector2.Store(dPtr + (uint)(Vector256.Count * 5)); - vector3.Store(dPtr + (uint)(Vector256.Count * 6)); - vector4.Store(dPtr + (uint)(Vector256.Count * 7)); + vector1.Store(dPtr + (uint)(Vector256.Count * 4)); + vector2.Store(dPtr + (uint)(Vector256.Count * 5)); + vector3.Store(dPtr + (uint)(Vector256.Count * 6)); + vector4.Store(dPtr + (uint)(Vector256.Count * 7)); // We adjust the source and destination references, then update // the count of remaining elements to process. - xPtr += (uint)(Vector256.Count * 8); - dPtr += (uint)(Vector256.Count * 8); + xPtr += (uint)(Vector256.Count * 8); + dPtr += (uint)(Vector256.Count * 8); - remainder -= (uint)(Vector256.Count * 8); + remainder -= (uint)(Vector256.Count * 8); } } @@ -3375,63 +3427,63 @@ static void Vectorized256(ref float xRef, ref float dRef, nuint remainder) // data before the first aligned address. 
nuint endIndex = remainder; - remainder = (remainder + (uint)(Vector256.Count - 1)) & (nuint)(-Vector256.Count); + remainder = (remainder + (uint)(Vector256.Count - 1)) & (nuint)(-Vector256.Count); - switch (remainder / (uint)(Vector256.Count)) + switch (remainder / (uint)Vector256.Count) { case 8: { - Vector256 vector = TUnaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 8))); - vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 8)); + Vector256 vector = TUnaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 8))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 8)); goto case 7; } case 7: { - Vector256 vector = TUnaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 7))); - vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 7)); + Vector256 vector = TUnaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 7))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 7)); goto case 6; } case 6: { - Vector256 vector = TUnaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 6))); - vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 6)); + Vector256 vector = TUnaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 6))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 6)); goto case 5; } case 5: { - Vector256 vector = TUnaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 5))); - vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 5)); + Vector256 vector = TUnaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 5))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 5)); goto case 4; } case 4: { - Vector256 vector = TUnaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 4))); - vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 4)); + Vector256 vector = TUnaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 4))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 4)); goto case 3; } case 3: { - Vector256 vector = TUnaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 3))); - vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 3)); + Vector256 vector = TUnaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 3))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 3)); goto case 2; } case 2: { - Vector256 vector = TUnaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 2))); - vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 2)); + Vector256 vector = TUnaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 2))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 2)); goto case 1; } case 1: { // Store the last block, which includes any elements that wouldn't fill a full vector - end.StoreUnsafe(ref dRef, endIndex - (uint)Vector256.Count); + end.StoreUnsafe(ref dRef, endIndex - (uint)Vector256.Count); goto case 0; } @@ -3445,7 +3497,7 @@ static void Vectorized256(ref float xRef, ref float dRef, nuint remainder) } [MethodImpl(MethodImplOptions.AggressiveInlining)] - static void Vectorized256Small(ref 
         [MethodImpl(MethodImplOptions.AggressiveInlining)]
-        static void Vectorized256Small(ref float xRef, ref float dRef, nuint remainder)
+        static void Vectorized256Small(ref T xRef, ref T dRef, nuint remainder)
         {
             switch (remainder)
             {
@@ -3455,11 +3507,11 @@ static void Vectorized256Small(ref float xRef, ref float dRef, nuint remainder)
                 {
                     Debug.Assert(Vector128.IsHardwareAccelerated);

-                    Vector128<float> beg = TUnaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef));
-                    Vector128<float> end = TUnaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128<float>.Count)));
+                    Vector128<T> beg = TUnaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef));
+                    Vector128<T> end = TUnaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)Vector128<T>.Count));

                     beg.StoreUnsafe(ref dRef);
-                    end.StoreUnsafe(ref dRef, remainder - (uint)(Vector128<float>.Count));
+                    end.StoreUnsafe(ref dRef, remainder - (uint)Vector128<T>.Count);

                     break;
                 }

@@ -3468,7 +3520,7 @@ static void Vectorized256Small(ref float xRef, ref float dRef, nuint remainder)
                 {
                     Debug.Assert(Vector128.IsHardwareAccelerated);

-                    Vector128<float> beg = TUnaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef));
+                    Vector128<T> beg = TUnaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef));
                     beg.StoreUnsafe(ref dRef);

                     break;
@@ -3499,31 +3551,31 @@ static void Vectorized256Small(ref float xRef, ref float dRef, nuint remainder)
             }
         }
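Each width also has a Small path for inputs below one full vector: lengths that fit the next width down reuse the same overlap trick there, and the last few elements fall back to scalar code. A hedged sketch of that shape (`DoubleSmall` is hypothetical and mirrors the structure, not the exact BCL switch):

using System;
using System.Runtime.InteropServices;
using System.Runtime.Intrinsics;

class SmallPathDemo
{
    // Handles remainders of 0..7 floats: 4..7 use two overlapping 128-bit
    // vectors; 1..3 fall back to scalar, chained with goto like the original.
    static void DoubleSmall(ReadOnlySpan<float> x, Span<float> d)
    {
        ref float xRef = ref MemoryMarshal.GetReference(x);
        ref float dRef = ref MemoryMarshal.GetReference(d);
        Vector128<float> two = Vector128.Create(2f);

        switch (x.Length)
        {
            case 7:
            case 6:
            case 5:
            {
                Vector128<float> beg = Vector128.LoadUnsafe(ref xRef) * two;
                Vector128<float> end = Vector128.LoadUnsafe(ref xRef, (nuint)(x.Length - 4)) * two;
                beg.StoreUnsafe(ref dRef);
                end.StoreUnsafe(ref dRef, (nuint)(x.Length - 4));
                break;
            }
            case 4:
            {
                (Vector128.LoadUnsafe(ref xRef) * two).StoreUnsafe(ref dRef);
                break;
            }
            case 3: d[2] = x[2] * 2f; goto case 2;
            case 2: d[1] = x[1] * 2f; goto case 1;
            case 1: d[0] = x[0] * 2f; break;
            case 0: break;
        }
    }

    static void Main()
    {
        float[] src = { 1, 2, 3, 4, 5 };
        float[] dst = new float[5];
        DoubleSmall(src, dst);
        Console.WriteLine(string.Join(", ", dst)); // 2, 4, 6, 8, 10
    }
}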
-        static void Vectorized512(ref float xRef, ref float dRef, nuint remainder)
+        static void Vectorized512(ref T xRef, ref T dRef, nuint remainder)
         {
-            ref float dRefBeg = ref dRef;
+            ref T dRefBeg = ref dRef;

             // Preload the beginning and end so that overlapping accesses don't negatively impact the data

-            Vector512<float> beg = TUnaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef));
-            Vector512<float> end = TUnaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512<float>.Count)));
+            Vector512<T> beg = TUnaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef));
+            Vector512<T> end = TUnaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)Vector512<T>.Count));

-            if (remainder > (uint)(Vector512<float>.Count * 8))
+            if (remainder > (uint)(Vector512<T>.Count * 8))
             {
                 // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful
                 // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted.

-                fixed (float* px = &xRef)
-                fixed (float* pd = &dRef)
+                fixed (T* px = &xRef)
+                fixed (T* pd = &dRef)
                 {
-                    float* xPtr = px;
-                    float* dPtr = pd;
+                    T* xPtr = px;
+                    T* dPtr = pd;

                     // We need to the ensure the underlying data can be aligned and only align
                     // it if it can. It is possible we have an unaligned ref, in which case we
                     // can never achieve the required SIMD alignment.

-                    bool canAlign = ((nuint)(dPtr) % sizeof(float)) == 0;
+                    bool canAlign = ((nuint)dPtr % (nuint)sizeof(T)) == 0;

                     if (canAlign)
                     {
@@ -3533,96 +3585,96 @@ static void Vectorized512(ref float xRef, ref float dRef, nuint remainder)
                         // are more expensive than unaligned loads and aligning both is significantly more
                         // complex.

-                        nuint misalignment = ((uint)(sizeof(Vector512<float>)) - ((nuint)(dPtr) % (uint)(sizeof(Vector512<float>)))) / sizeof(float);
+                        nuint misalignment = ((uint)sizeof(Vector512<T>) - ((nuint)dPtr % (uint)sizeof(Vector512<T>))) / (uint)sizeof(T);

                         xPtr += misalignment;
                         dPtr += misalignment;

-                        Debug.Assert(((nuint)(dPtr) % (uint)(sizeof(Vector512<float>))) == 0);
+                        Debug.Assert(((nuint)dPtr % (uint)sizeof(Vector512<T>)) == 0);

                         remainder -= misalignment;
                     }

-                    Vector512<float> vector1;
-                    Vector512<float> vector2;
-                    Vector512<float> vector3;
-                    Vector512<float> vector4;
+                    Vector512<T> vector1;
+                    Vector512<T> vector2;
+                    Vector512<T> vector3;
+                    Vector512<T> vector4;

-                    if ((remainder > (NonTemporalByteThreshold / sizeof(float))) && canAlign)
+                    if ((remainder > (NonTemporalByteThreshold / (nuint)sizeof(T))) && canAlign)
                     {
                         // This loop stores the data non-temporally, which benefits us when there
                         // is a large amount of data involved as it avoids polluting the cache.

-                        while (remainder >= (uint)(Vector512<float>.Count * 8))
+                        while (remainder >= (uint)(Vector512<T>.Count * 8))
                         {
                             // We load, process, and store the first four vectors

-                            vector1 = TUnaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512<float>.Count * 0)));
-                            vector2 = TUnaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512<float>.Count * 1)));
-                            vector3 = TUnaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512<float>.Count * 2)));
-                            vector4 = TUnaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512<float>.Count * 3)));
+                            vector1 = TUnaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512<T>.Count * 0)));
+                            vector2 = TUnaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512<T>.Count * 1)));
+                            vector3 = TUnaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512<T>.Count * 2)));
+                            vector4 = TUnaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512<T>.Count * 3)));

-                            vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector512<float>.Count * 0));
-                            vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector512<float>.Count * 1));
-                            vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector512<float>.Count * 2));
-                            vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector512<float>.Count * 3));
+                            vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector512<T>.Count * 0));
+                            vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector512<T>.Count * 1));
+                            vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector512<T>.Count * 2));
+                            vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector512<T>.Count * 3));

                             // We load, process, and store the next four vectors

-                            vector1 = TUnaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512<float>.Count * 4)));
-                            vector2 = TUnaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512<float>.Count * 5)));
-                            vector3 = TUnaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512<float>.Count * 6)));
-                            vector4 = TUnaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512<float>.Count * 7)));
+                            vector1 = TUnaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512<T>.Count * 4)));
+                            vector2 = TUnaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512<T>.Count * 5)));
+                            vector3 = TUnaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512<T>.Count * 6)));
+                            vector4 = TUnaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512<T>.Count * 7)));

-                            vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector512<float>.Count * 4));
-                            vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector512<float>.Count * 5));
-                            vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector512<float>.Count * 6));
-                            vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector512<float>.Count * 7));
+                            vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector512<T>.Count * 4));
+                            vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector512<T>.Count * 5));
+                            vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector512<T>.Count * 6));
+                            vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector512<T>.Count * 7));

                             // We adjust the source and destination references, then update
                             // the count of remaining elements to process.

-                            xPtr += (uint)(Vector512<float>.Count * 8);
-                            dPtr += (uint)(Vector512<float>.Count * 8);
+                            xPtr += (uint)(Vector512<T>.Count * 8);
+                            dPtr += (uint)(Vector512<T>.Count * 8);

-                            remainder -= (uint)(Vector512<float>.Count * 8);
+                            remainder -= (uint)(Vector512<T>.Count * 8);
                         }
                     }
                     else
                     {
-                        while (remainder >= (uint)(Vector512<float>.Count * 8))
+                        while (remainder >= (uint)(Vector512<T>.Count * 8))
                         {
                             // We load, process, and store the first four vectors

-                            vector1 = TUnaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512<float>.Count * 0)));
-                            vector2 = TUnaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512<float>.Count * 1)));
-                            vector3 = TUnaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512<float>.Count * 2)));
-                            vector4 = TUnaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512<float>.Count * 3)));
+                            vector1 = TUnaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512<T>.Count * 0)));
+                            vector2 = TUnaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512<T>.Count * 1)));
+                            vector3 = TUnaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512<T>.Count * 2)));
+                            vector4 = TUnaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512<T>.Count * 3)));

-                            vector1.Store(dPtr + (uint)(Vector512<float>.Count * 0));
-                            vector2.Store(dPtr + (uint)(Vector512<float>.Count * 1));
-                            vector3.Store(dPtr + (uint)(Vector512<float>.Count * 2));
-                            vector4.Store(dPtr + (uint)(Vector512<float>.Count * 3));
+                            vector1.Store(dPtr + (uint)(Vector512<T>.Count * 0));
+                            vector2.Store(dPtr + (uint)(Vector512<T>.Count * 1));
+                            vector3.Store(dPtr + (uint)(Vector512<T>.Count * 2));
+                            vector4.Store(dPtr + (uint)(Vector512<T>.Count * 3));

                             // We load, process, and store the next four vectors

-                            vector1 = TUnaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512<float>.Count * 4)));
-                            vector2 = TUnaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512<float>.Count * 5)));
-                            vector3 = TUnaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512<float>.Count * 6)));
-                            vector4 = TUnaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512<float>.Count * 7)));
+                            vector1 = TUnaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512<T>.Count * 4)));
+                            vector2 = TUnaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512<T>.Count * 5)));
+                            vector3 = TUnaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512<T>.Count * 6)));
+                            vector4 = TUnaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512<T>.Count * 7)));

-                            vector1.Store(dPtr + (uint)(Vector512<float>.Count * 4));
-                            vector2.Store(dPtr + (uint)(Vector512<float>.Count * 5));
-                            vector3.Store(dPtr + (uint)(Vector512<float>.Count * 6));
-                            vector4.Store(dPtr + (uint)(Vector512<float>.Count * 7));
+                            vector1.Store(dPtr + (uint)(Vector512<T>.Count * 4));
+                            vector2.Store(dPtr + (uint)(Vector512<T>.Count * 5));
+                            vector3.Store(dPtr + (uint)(Vector512<T>.Count * 6));
+                            vector4.Store(dPtr + (uint)(Vector512<T>.Count * 7));

                             // We adjust the source and destination references, then update
                             // the count of remaining elements to process.

-                            xPtr += (uint)(Vector512<float>.Count * 8);
-                            dPtr += (uint)(Vector512<float>.Count * 8);
+                            xPtr += (uint)(Vector512<T>.Count * 8);
+                            dPtr += (uint)(Vector512<T>.Count * 8);

-                            remainder -= (uint)(Vector512<float>.Count * 8);
+                            remainder -= (uint)(Vector512<T>.Count * 8);
                         }
                     }

@@ -3641,63 +3693,63 @@ static void Vectorized512(ref float xRef, ref float dRef, nuint remainder)
                 // data before the first aligned address.

                 nuint endIndex = remainder;
-                remainder = (remainder + (uint)(Vector512<float>.Count - 1)) & (nuint)(-Vector512<float>.Count);
+                remainder = (remainder + (uint)(Vector512<T>.Count - 1)) & (nuint)(-Vector512<T>.Count);

-                switch (remainder / (uint)(Vector512<float>.Count))
+                switch (remainder / (uint)Vector512<T>.Count)
                 {
                     case 8:
                     {
-                        Vector512<float> vector = TUnaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512<float>.Count * 8)));
-                        vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512<float>.Count * 8));
+                        Vector512<T> vector = TUnaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512<T>.Count * 8)));
+                        vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512<T>.Count * 8));
                         goto case 7;
                     }

                     case 7:
                     {
-                        Vector512<float> vector = TUnaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512<float>.Count * 7)));
-                        vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512<float>.Count * 7));
+                        Vector512<T> vector = TUnaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512<T>.Count * 7)));
+                        vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512<T>.Count * 7));
                         goto case 6;
                     }

                     case 6:
                     {
-                        Vector512<float> vector = TUnaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512<float>.Count * 6)));
-                        vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512<float>.Count * 6));
+                        Vector512<T> vector = TUnaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512<T>.Count * 6)));
+                        vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512<T>.Count * 6));
                         goto case 5;
                     }

                     case 5:
                     {
-                        Vector512<float> vector = TUnaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512<float>.Count * 5)));
-                        vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512<float>.Count * 5));
+                        Vector512<T> vector = TUnaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512<T>.Count * 5)));
+                        vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512<T>.Count * 5));
                         goto case 4;
                     }

                     case 4:
                     {
-                        Vector512<float> vector = TUnaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512<float>.Count * 4)));
-                        vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512<float>.Count * 4));
+                        Vector512<T> vector = TUnaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512<T>.Count * 4)));
+                        vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512<T>.Count * 4));
                         goto case 3;
                     }

                     case 3:
                     {
-                        Vector512<float> vector = TUnaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512<float>.Count * 3)));
-                        vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512<float>.Count * 3));
+                        Vector512<T> vector = TUnaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512<T>.Count * 3)));
+                        vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512<T>.Count * 3));
                         goto case 2;
                     }

                     case 2:
                     {
-                        Vector512<float> vector = TUnaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512<float>.Count * 2)));
-                        vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512<float>.Count * 2));
+                        Vector512<T> vector = TUnaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512<T>.Count * 2)));
+                        vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512<T>.Count * 2));
                         goto case 1;
                     }

                     case 1:
                     {
                         // Store the last block, which includes any elements that wouldn't fill a full vector
-                        end.StoreUnsafe(ref dRef, endIndex - (uint)Vector512<float>.Count);
+                        end.StoreUnsafe(ref dRef, endIndex - (uint)Vector512<T>.Count);
                         goto case 0;
                     }

@@ -3711,7 +3763,7 @@ static void Vectorized512(ref float xRef, ref float dRef, nuint remainder)
         }
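The masking line above rounds remainder up to the next multiple of the vector count (always a power of two), so the switch dispatches on how many whole vectors cover the tail, while endIndex keeps the true element count for the final overlapping store. A small standalone check of that arithmetic, assuming Vector512<float>.Count == 16 on AVX-512 hardware:

using System;

nuint remainder = 37;                 // elements left to process
nuint count = 16;                     // assumed Vector512<float>.Count
nuint rounded = (remainder + (count - 1)) & (nuint)(-(nint)count);
Console.WriteLine(rounded);           // 48: 37 rounded up to a multiple of 16
Console.WriteLine(rounded / count);   // 3:  the switch enters at 'case 3'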
         [MethodImpl(MethodImplOptions.AggressiveInlining)]
-        static void Vectorized512Small(ref float xRef, ref float dRef, nuint remainder)
+        static void Vectorized512Small(ref T xRef, ref T dRef, nuint remainder)
         {
             switch (remainder)
             {
@@ -3725,11 +3777,11 @@ static void Vectorized512Small(ref float xRef, ref float dRef, nuint remainder)
                 {
                     Debug.Assert(Vector256.IsHardwareAccelerated);

-                    Vector256<float> beg = TUnaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef));
-                    Vector256<float> end = TUnaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256<float>.Count)));
+                    Vector256<T> beg = TUnaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef));
+                    Vector256<T> end = TUnaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)Vector256<T>.Count));

                     beg.StoreUnsafe(ref dRef);
-                    end.StoreUnsafe(ref dRef, remainder - (uint)(Vector256<float>.Count));
+                    end.StoreUnsafe(ref dRef, remainder - (uint)Vector256<T>.Count);

                     break;
                 }

@@ -3738,7 +3790,7 @@ static void Vectorized512Small(ref float xRef, ref float dRef, nuint remainder)
                 {
                     Debug.Assert(Vector256.IsHardwareAccelerated);

-                    Vector256<float> beg = TUnaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef));
+                    Vector256<T> beg = TUnaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef));
                     beg.StoreUnsafe(ref dRef);

                     break;
@@ -3750,11 +3802,11 @@ static void Vectorized512Small(ref float xRef, ref float dRef, nuint remainder)
                 {
                     Debug.Assert(Vector128.IsHardwareAccelerated);

-                    Vector128<float> beg = TUnaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef));
-                    Vector128<float> end = TUnaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128<float>.Count)));
+                    Vector128<T> beg = TUnaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef));
+                    Vector128<T> end = TUnaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)Vector128<T>.Count));

                     beg.StoreUnsafe(ref dRef);
-                    end.StoreUnsafe(ref dRef, remainder - (uint)(Vector128<float>.Count));
+                    end.StoreUnsafe(ref dRef, remainder - (uint)Vector128<T>.Count);

                     break;
                 }

@@ -3763,7 +3815,7 @@ static void Vectorized512Small(ref float xRef, ref float dRef, nuint remainder)
                 {
                     Debug.Assert(Vector128.IsHardwareAccelerated);

-                    Vector128<float> beg = TUnaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef));
+                    Vector128<T> beg = TUnaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef));
                     beg.StoreUnsafe(ref dRef);

                     break;
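The misalignment expression used by each Vectorized path counts the elements to skip so the destination pointer reaches the next sizeof(Vector512<T>) boundary, enabling aligned (and, for very large inputs, non-temporal) stores; when the pointer is already aligned it skips one full vector, which is safe because `beg` was stored first. A minimal unsafe sketch of the same pointer arithmetic, specialized to float (compile with unsafe blocks enabled):

using System;
using System.Runtime.Intrinsics;

class AlignDemo
{
    static unsafe void Main()
    {
        float[] data = new float[64];
        fixed (float* p = data)
        {
            float* dPtr = p;
            // Same expression as the diff, with T = float:
            nuint misalignment = ((uint)sizeof(Vector512<float>) - ((nuint)dPtr % (uint)sizeof(Vector512<float>))) / sizeof(float);
            dPtr += misalignment;
            // dPtr now sits on a 64-byte (sizeof(Vector512<float>)) boundary:
            Console.WriteLine((nuint)dPtr % (uint)sizeof(Vector512<float>)); // 0
        }
    }
}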
@@ -3799,12 +3851,13 @@ static void Vectorized512Small(ref float xRef, ref float dRef, nuint remainder)
         /// <summary>
         /// Performs an element-wise operation on <paramref name="x"/> and <paramref name="y"/>,
         /// and writes the results to <paramref name="destination"/>.
         /// </summary>
-        /// <typeparam name="TBinaryOperator">
+        /// <typeparam name="T">The element type.</typeparam>
+        /// <typeparam name="TBinaryOperator">
         /// Specifies the operation to perform on the pair-wise elements loaded from <paramref name="x"/> and <paramref name="y"/>.
         /// </typeparam>
-        private static void InvokeSpanSpanIntoSpan<TBinaryOperator>(
-            ReadOnlySpan<float> x, ReadOnlySpan<float> y, Span<float> destination)
-            where TBinaryOperator : struct, IBinaryOperator
+        private static void InvokeSpanSpanIntoSpan<T, TBinaryOperator>(
+            ReadOnlySpan<T> x, ReadOnlySpan<T> y, Span<T> destination)
+            where TBinaryOperator : struct, IBinaryOperator<T>
         {
             if (x.Length != y.Length)
             {
@@ -3824,15 +3877,15 @@ private static void InvokeSpanSpanIntoSpan<TBinaryOperator>(
             // in a way that allows us to have the minimum possible
             // for small sizes

-            ref float xRef = ref MemoryMarshal.GetReference(x);
-            ref float yRef = ref MemoryMarshal.GetReference(y);
-            ref float dRef = ref MemoryMarshal.GetReference(destination);
+            ref T xRef = ref MemoryMarshal.GetReference(x);
+            ref T yRef = ref MemoryMarshal.GetReference(y);
+            ref T dRef = ref MemoryMarshal.GetReference(destination);

-            nuint remainder = (uint)(x.Length);
+            nuint remainder = (uint)x.Length;

-            if (Vector512.IsHardwareAccelerated)
+            if (Vector512.IsHardwareAccelerated && Vector512<T>.IsSupported && Unsafe.SizeOf<T>() >= 4)
             {
-                if (remainder >= (uint)(Vector512<float>.Count))
+                if (remainder >= (uint)Vector512<T>.Count)
                 {
                     Vectorized512(ref xRef, ref yRef, ref dRef, remainder);
                 }
@@ -3848,9 +3901,9 @@ private static void InvokeSpanSpanIntoSpan<TBinaryOperator>(
                 return;
             }

-            if (Vector256.IsHardwareAccelerated)
+            if (Vector256.IsHardwareAccelerated && Vector256<T>.IsSupported && Unsafe.SizeOf<T>() >= 4)
             {
-                if (remainder >= (uint)(Vector256<float>.Count))
+                if (remainder >= (uint)Vector256<T>.Count)
                 {
                     Vectorized256(ref xRef, ref yRef, ref dRef, remainder);
                 }
@@ -3866,9 +3919,9 @@ private static void InvokeSpanSpanIntoSpan<TBinaryOperator>(
                 return;
             }

-            if (Vector128.IsHardwareAccelerated)
+            if (Vector128.IsHardwareAccelerated && Vector128<T>.IsSupported && Unsafe.SizeOf<T>() >= 4)
             {
-                if (remainder >= (uint)(Vector128<float>.Count))
+                if (remainder >= (uint)Vector128<T>.Count)
                 {
                     Vectorized128(ref xRef, ref yRef, ref dRef, remainder);
                 }
@@ -3890,7 +3943,7 @@ private static void InvokeSpanSpanIntoSpan<TBinaryOperator>(
             SoftwareFallback(ref xRef, ref yRef, ref dRef, remainder);

             [MethodImpl(MethodImplOptions.AggressiveInlining)]
-            static void SoftwareFallback(ref float xRef, ref float yRef, ref float dRef, nuint length)
+            static void SoftwareFallback(ref T xRef, ref T yRef, ref T dRef, nuint length)
             {
                 for (nuint i = 0; i < length; i++)
                 {
@@ -3899,35 +3952,35 @@ static void SoftwareFallback(ref float xRef, ref float yRef, ref float dRef, nui
                 }
             }
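Each vector path is now additionally gated on Vector512<T>.IsSupported and Unsafe.SizeOf<T>() >= 4, since this part of the change only vectorizes 4-byte-and-wider element types; everything else takes the scalar fallback. A hedged sketch of that dispatch shape (SelectPath is an illustrative method, not the internal code):

using System;
using System.Runtime.CompilerServices;
using System.Runtime.Intrinsics;

static class DispatchDemo
{
    // Mirrors the width-selection logic of InvokeSpanSpanIntoSpan<T, TBinaryOperator>.
    public static string SelectPath<T>(int length) where T : struct
    {
        if (Vector512.IsHardwareAccelerated && Vector512<T>.IsSupported && Unsafe.SizeOf<T>() >= 4)
            return length >= Vector512<T>.Count ? "Vectorized512" : "Vectorized512Small";

        if (Vector256.IsHardwareAccelerated && Vector256<T>.IsSupported && Unsafe.SizeOf<T>() >= 4)
            return length >= Vector256<T>.Count ? "Vectorized256" : "Vectorized256Small";

        if (Vector128.IsHardwareAccelerated && Vector128<T>.IsSupported && Unsafe.SizeOf<T>() >= 4)
            return length >= Vector128<T>.Count ? "Vectorized128" : "Vectorized128Small";

        return "SoftwareFallback";
    }

    static void Main()
    {
        Console.WriteLine(SelectPath<float>(1000)); // wide path on SIMD hardware
        Console.WriteLine(SelectPath<byte>(1000));  // SoftwareFallback: sizeof(byte) < 4
    }
}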
-            static void Vectorized128(ref float xRef, ref float yRef, ref float dRef, nuint remainder)
+            static void Vectorized128(ref T xRef, ref T yRef, ref T dRef, nuint remainder)
             {
-                ref float dRefBeg = ref dRef;
+                ref T dRefBeg = ref dRef;

                 // Preload the beginning and end so that overlapping accesses don't negatively impact the data

-                Vector128<float> beg = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef),
-                                                              Vector128.LoadUnsafe(ref yRef));
-                Vector128<float> end = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128<float>.Count)),
-                                                              Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128<float>.Count)));
+                Vector128<T> beg = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef),
+                                                          Vector128.LoadUnsafe(ref yRef));
+                Vector128<T> end = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)Vector128<T>.Count),
+                                                          Vector128.LoadUnsafe(ref yRef, remainder - (uint)Vector128<T>.Count));

-                if (remainder > (uint)(Vector128<float>.Count * 8))
+                if (remainder > (uint)(Vector128<T>.Count * 8))
                 {
                     // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful
                     // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted.

-                    fixed (float* px = &xRef)
-                    fixed (float* py = &yRef)
-                    fixed (float* pd = &dRef)
+                    fixed (T* px = &xRef)
+                    fixed (T* py = &yRef)
+                    fixed (T* pd = &dRef)
                     {
-                        float* xPtr = px;
-                        float* yPtr = py;
-                        float* dPtr = pd;
+                        T* xPtr = px;
+                        T* yPtr = py;
+                        T* dPtr = pd;

                         // We need to the ensure the underlying data can be aligned and only align
                         // it if it can. It is possible we have an unaligned ref, in which case we
                         // can never achieve the required SIMD alignment.

-                        bool canAlign = ((nuint)(dPtr) % sizeof(float)) == 0;
+                        bool canAlign = ((nuint)dPtr % (nuint)sizeof(T)) == 0;

                         if (canAlign)
                         {
@@ -3937,115 +3990,115 @@ static void Vectorized128(ref float xRef, ref float yRef, ref float dRef, nuint
                             // are more expensive than unaligned loads and aligning both is significantly more
                             // complex.

-                            nuint misalignment = ((uint)(sizeof(Vector128<float>)) - ((nuint)(dPtr) % (uint)(sizeof(Vector128<float>)))) / sizeof(float);
+                            nuint misalignment = ((uint)sizeof(Vector128<T>) - ((nuint)dPtr % (uint)sizeof(Vector128<T>))) / (uint)sizeof(T);

                             xPtr += misalignment;
                             yPtr += misalignment;
                             dPtr += misalignment;

-                            Debug.Assert(((nuint)(dPtr) % (uint)(sizeof(Vector128<float>))) == 0);
+                            Debug.Assert(((nuint)dPtr % (uint)sizeof(Vector128<T>)) == 0);

                             remainder -= misalignment;
                         }

-                        Vector128<float> vector1;
-                        Vector128<float> vector2;
-                        Vector128<float> vector3;
-                        Vector128<float> vector4;
+                        Vector128<T> vector1;
+                        Vector128<T> vector2;
+                        Vector128<T> vector3;
+                        Vector128<T> vector4;

-                        if ((remainder > (NonTemporalByteThreshold / sizeof(float))) && canAlign)
+                        if ((remainder > (NonTemporalByteThreshold / (nuint)sizeof(T))) && canAlign)
                         {
                             // This loop stores the data non-temporally, which benefits us when there
                             // is a large amount of data involved as it avoids polluting the cache.

-                            while (remainder >= (uint)(Vector128<float>.Count * 8))
+                            while (remainder >= (uint)(Vector128<T>.Count * 8))
                             {
                                 // We load, process, and store the first four vectors

-                                vector1 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128<float>.Count * 0)),
-                                                                 Vector128.Load(yPtr + (uint)(Vector128<float>.Count * 0)));
-                                vector2 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128<float>.Count * 1)),
-                                                                 Vector128.Load(yPtr + (uint)(Vector128<float>.Count * 1)));
-                                vector3 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128<float>.Count * 2)),
-                                                                 Vector128.Load(yPtr + (uint)(Vector128<float>.Count * 2)));
-                                vector4 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128<float>.Count * 3)),
-                                                                 Vector128.Load(yPtr + (uint)(Vector128<float>.Count * 3)));
+                                vector1 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128<T>.Count * 0)),
+                                                                 Vector128.Load(yPtr + (uint)(Vector128<T>.Count * 0)));
+                                vector2 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128<T>.Count * 1)),
+                                                                 Vector128.Load(yPtr + (uint)(Vector128<T>.Count * 1)));
+                                vector3 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128<T>.Count * 2)),
+                                                                 Vector128.Load(yPtr + (uint)(Vector128<T>.Count * 2)));
+                                vector4 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128<T>.Count * 3)),
+                                                                 Vector128.Load(yPtr + (uint)(Vector128<T>.Count * 3)));

-                                vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector128<float>.Count * 0));
-                                vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector128<float>.Count * 1));
-                                vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector128<float>.Count * 2));
-                                vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector128<float>.Count * 3));
+                                vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector128<T>.Count * 0));
+                                vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector128<T>.Count * 1));
+                                vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector128<T>.Count * 2));
+                                vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector128<T>.Count * 3));

                                 // We load, process, and store the next four vectors

-                                vector1 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128<float>.Count * 4)),
-                                                                 Vector128.Load(yPtr + (uint)(Vector128<float>.Count * 4)));
-                                vector2 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128<float>.Count * 5)),
-                                                                 Vector128.Load(yPtr + (uint)(Vector128<float>.Count * 5)));
-                                vector3 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128<float>.Count * 6)),
-                                                                 Vector128.Load(yPtr + (uint)(Vector128<float>.Count * 6)));
-                                vector4 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128<float>.Count * 7)),
-                                                                 Vector128.Load(yPtr + (uint)(Vector128<float>.Count * 7)));
+                                vector1 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128<T>.Count * 4)),
+                                                                 Vector128.Load(yPtr + (uint)(Vector128<T>.Count * 4)));
+                                vector2 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128<T>.Count * 5)),
+                                                                 Vector128.Load(yPtr + (uint)(Vector128<T>.Count * 5)));
+                                vector3 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128<T>.Count * 6)),
+                                                                 Vector128.Load(yPtr + (uint)(Vector128<T>.Count * 6)));
+                                vector4 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128<T>.Count * 7)),
+                                                                 Vector128.Load(yPtr + (uint)(Vector128<T>.Count * 7)));

-                                vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector128<float>.Count * 4));
-                                vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector128<float>.Count * 5));
-                                vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector128<float>.Count * 6));
-                                vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector128<float>.Count * 7));
+                                vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector128<T>.Count * 4));
+                                vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector128<T>.Count * 5));
+                                vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector128<T>.Count * 6));
+                                vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector128<T>.Count * 7));

                                 // We adjust the source and destination references, then update
                                 // the count of remaining elements to process.

-                                xPtr += (uint)(Vector128<float>.Count * 8);
-                                yPtr += (uint)(Vector128<float>.Count * 8);
-                                dPtr += (uint)(Vector128<float>.Count * 8);
+                                xPtr += (uint)(Vector128<T>.Count * 8);
+                                yPtr += (uint)(Vector128<T>.Count * 8);
+                                dPtr += (uint)(Vector128<T>.Count * 8);

-                                remainder -= (uint)(Vector128<float>.Count * 8);
+                                remainder -= (uint)(Vector128<T>.Count * 8);
                             }
                         }
                         else
                         {
-                            while (remainder >= (uint)(Vector128<float>.Count * 8))
+                            while (remainder >= (uint)(Vector128<T>.Count * 8))
                             {
                                 // We load, process, and store the first four vectors

-                                vector1 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128<float>.Count * 0)),
-                                                                 Vector128.Load(yPtr + (uint)(Vector128<float>.Count * 0)));
-                                vector2 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128<float>.Count * 1)),
-                                                                 Vector128.Load(yPtr + (uint)(Vector128<float>.Count * 1)));
-                                vector3 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128<float>.Count * 2)),
-                                                                 Vector128.Load(yPtr + (uint)(Vector128<float>.Count * 2)));
-                                vector4 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128<float>.Count * 3)),
-                                                                 Vector128.Load(yPtr + (uint)(Vector128<float>.Count * 3)));
+                                vector1 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128<T>.Count * 0)),
+                                                                 Vector128.Load(yPtr + (uint)(Vector128<T>.Count * 0)));
+                                vector2 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128<T>.Count * 1)),
+                                                                 Vector128.Load(yPtr + (uint)(Vector128<T>.Count * 1)));
+                                vector3 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128<T>.Count * 2)),
+                                                                 Vector128.Load(yPtr + (uint)(Vector128<T>.Count * 2)));
+                                vector4 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128<T>.Count * 3)),
+                                                                 Vector128.Load(yPtr + (uint)(Vector128<T>.Count * 3)));

-                                vector1.Store(dPtr + (uint)(Vector128<float>.Count * 0));
-                                vector2.Store(dPtr + (uint)(Vector128<float>.Count * 1));
-                                vector3.Store(dPtr + (uint)(Vector128<float>.Count * 2));
-                                vector4.Store(dPtr + (uint)(Vector128<float>.Count * 3));
+                                vector1.Store(dPtr + (uint)(Vector128<T>.Count * 0));
+                                vector2.Store(dPtr + (uint)(Vector128<T>.Count * 1));
+                                vector3.Store(dPtr + (uint)(Vector128<T>.Count * 2));
+                                vector4.Store(dPtr + (uint)(Vector128<T>.Count * 3));

                                 // We load, process, and store the next four vectors

-                                vector1 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128<float>.Count * 4)),
-                                                                 Vector128.Load(yPtr + (uint)(Vector128<float>.Count * 4)));
-                                vector2 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128<float>.Count * 5)),
-                                                                 Vector128.Load(yPtr + (uint)(Vector128<float>.Count * 5)));
-                                vector3 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128<float>.Count * 6)),
-                                                                 Vector128.Load(yPtr + (uint)(Vector128<float>.Count * 6)));
-                                vector4 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128<float>.Count * 7)),
-                                                                 Vector128.Load(yPtr + (uint)(Vector128<float>.Count * 7)));
+                                vector1 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128<T>.Count * 4)),
+                                                                 Vector128.Load(yPtr + (uint)(Vector128<T>.Count * 4)));
+                                vector2 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128<T>.Count * 5)),
+                                                                 Vector128.Load(yPtr + (uint)(Vector128<T>.Count * 5)));
+                                vector3 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128<T>.Count * 6)),
+                                                                 Vector128.Load(yPtr + (uint)(Vector128<T>.Count * 6)));
+                                vector4 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128<T>.Count * 7)),
+                                                                 Vector128.Load(yPtr + (uint)(Vector128<T>.Count * 7)));

-                                vector1.Store(dPtr + (uint)(Vector128<float>.Count * 4));
-                                vector2.Store(dPtr + (uint)(Vector128<float>.Count * 5));
-                                vector3.Store(dPtr + (uint)(Vector128<float>.Count * 6));
-                                vector4.Store(dPtr + (uint)(Vector128<float>.Count * 7));
+                                vector1.Store(dPtr + (uint)(Vector128<T>.Count * 4));
+                                vector2.Store(dPtr + (uint)(Vector128<T>.Count * 5));
+                                vector3.Store(dPtr + (uint)(Vector128<T>.Count * 6));
+                                vector4.Store(dPtr + (uint)(Vector128<T>.Count * 7));

                                 // We adjust the source and destination references, then update
                                 // the count of remaining elements to process.

-                                xPtr += (uint)(Vector128<float>.Count * 8);
-                                yPtr += (uint)(Vector128<float>.Count * 8);
-                                dPtr += (uint)(Vector128<float>.Count * 8);
+                                xPtr += (uint)(Vector128<T>.Count * 8);
+                                yPtr += (uint)(Vector128<T>.Count * 8);
+                                dPtr += (uint)(Vector128<T>.Count * 8);

-                                remainder -= (uint)(Vector128<float>.Count * 8);
+                                remainder -= (uint)(Vector128<T>.Count * 8);
                             }
                         }

@@ -4065,70 +4118,70 @@ static void Vectorized128(ref float xRef, ref float yRef, ref float dRef, nuint
                 // data before the first aligned address.

                 nuint endIndex = remainder;
-                remainder = (remainder + (uint)(Vector128<float>.Count - 1)) & (nuint)(-Vector128<float>.Count);
+                remainder = (remainder + (uint)(Vector128<T>.Count - 1)) & (nuint)(-Vector128<T>.Count);

-                switch (remainder / (uint)(Vector128<float>.Count))
+                switch (remainder / (uint)Vector128<T>.Count)
                 {
                     case 8:
                     {
-                        Vector128<float> vector = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128<float>.Count * 8)),
-                                                                         Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128<float>.Count * 8)));
-                        vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128<float>.Count * 8));
+                        Vector128<T> vector = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128<T>.Count * 8)),
+                                                                     Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128<T>.Count * 8)));
+                        vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128<T>.Count * 8));
                        goto case 7;
                     }

                     case 7:
                     {
-                        Vector128<float> vector = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128<float>.Count * 7)),
-                                                                         Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128<float>.Count * 7)));
-                        vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128<float>.Count * 7));
+                        Vector128<T> vector = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128<T>.Count * 7)),
+                                                                     Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128<T>.Count * 7)));
+                        vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128<T>.Count * 7));
                        goto case 6;
                     }

                     case 6:
                     {
-                        Vector128<float> vector = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128<float>.Count * 6)),
-                                                                         Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128<float>.Count * 6)));
-                        vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128<float>.Count * 6));
+                        Vector128<T> vector = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128<T>.Count * 6)),
+                                                                     Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128<T>.Count * 6)));
+                        vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128<T>.Count * 6));
                        goto case 5;
                     }

                     case 5:
                     {
-                        Vector128<float> vector = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128<float>.Count * 5)),
-                                                                         Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128<float>.Count * 5)));
-                        vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128<float>.Count * 5));
+                        Vector128<T> vector = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128<T>.Count * 5)),
+                                                                     Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128<T>.Count * 5)));
+                        vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128<T>.Count * 5));
                        goto case 4;
                     }

                     case 4:
                     {
-                        Vector128<float> vector = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128<float>.Count * 4)),
-                                                                         Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128<float>.Count * 4)));
-                        vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128<float>.Count * 4));
+                        Vector128<T> vector = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128<T>.Count * 4)),
+                                                                     Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128<T>.Count * 4)));
+                        vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128<T>.Count * 4));
                        goto case 3;
                     }

                     case 3:
                     {
-                        Vector128<float> vector = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128<float>.Count * 3)),
-                                                                         Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128<float>.Count * 3)));
-                        vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128<float>.Count * 3));
+                        Vector128<T> vector = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128<T>.Count * 3)),
+                                                                     Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128<T>.Count * 3)));
+                        vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128<T>.Count * 3));
                        goto case 2;
                     }

                     case 2:
                     {
-                        Vector128<float> vector = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128<float>.Count * 2)),
-                                                                         Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128<float>.Count * 2)));
-                        vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128<float>.Count * 2));
+                        Vector128<T> vector = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128<T>.Count * 2)),
+                                                                     Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128<T>.Count * 2)));
+                        vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128<T>.Count * 2));
                        goto case 1;
                     }

                     case 1:
                     {
                         // Store the last block, which includes any elements that wouldn't fill a full vector
-                        end.StoreUnsafe(ref dRef, endIndex - (uint)Vector128<float>.Count);
+                        end.StoreUnsafe(ref dRef, endIndex - (uint)Vector128<T>.Count);
                         goto case 0;
                     }

@@ -4142,7 +4195,7 @@ static void Vectorized128(ref float xRef, ref float yRef, ref float dRef, nuint
             }

             [MethodImpl(MethodImplOptions.AggressiveInlining)]
-            static void Vectorized128Small(ref float xRef, ref float yRef, ref float dRef, nuint remainder)
+            static void Vectorized128Small(ref T xRef, ref T yRef, ref T dRef, nuint remainder)
             {
                 switch (remainder)
                 {
@@ -4173,35 +4226,35 @@ static void Vectorized128Small(ref float xRef, ref float yRef, ref float dRef, n
                 }
             }
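TBinaryOperator (like TUnaryOperator earlier) is a struct type parameter whose static abstract Invoke overloads the JIT specializes and inlines per operation, which is how one copy of this kernel serves Add, Subtract, Multiply, and the rest. A minimal sketch of the pattern under assumed names (IBinaryOp and AddOp are illustrative; the real internal interface also carries Vector256/Vector512 overloads):

using System;
using System.Runtime.InteropServices;
using System.Runtime.Intrinsics;

interface IBinaryOp<T>
{
    static abstract T Invoke(T x, T y);
    static abstract Vector128<T> Invoke(Vector128<T> x, Vector128<T> y);
}

readonly struct AddOp : IBinaryOp<float>
{
    public static float Invoke(float x, float y) => x + y;
    public static Vector128<float> Invoke(Vector128<float> x, Vector128<float> y) => x + y;
}

static class Apply
{
    // The JIT specializes this per TOp; each Invoke inlines to a single instruction.
    public static void Invoke<TOp>(ReadOnlySpan<float> x, ReadOnlySpan<float> y, Span<float> d)
        where TOp : struct, IBinaryOp<float>
    {
        ref float xRef = ref MemoryMarshal.GetReference(x);
        ref float yRef = ref MemoryMarshal.GetReference(y);
        ref float dRef = ref MemoryMarshal.GetReference(d);

        nuint i = 0, c = (nuint)Vector128<float>.Count;
        for (; i + c <= (nuint)x.Length; i += c)
        {
            TOp.Invoke(Vector128.LoadUnsafe(ref xRef, i), Vector128.LoadUnsafe(ref yRef, i))
               .StoreUnsafe(ref dRef, i);
        }
        for (; i < (nuint)x.Length; i++)
        {
            d[(int)i] = TOp.Invoke(x[(int)i], y[(int)i]); // scalar tail
        }
    }
}
// Usage: Apply.Invoke<AddOp>(a, b, dest);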
-            static void Vectorized256(ref float xRef, ref float yRef, ref float dRef, nuint remainder)
+            static void Vectorized256(ref T xRef, ref T yRef, ref T dRef, nuint remainder)
             {
-                ref float dRefBeg = ref dRef;
+                ref T dRefBeg = ref dRef;

                 // Preload the beginning and end so that overlapping accesses don't negatively impact the data

-                Vector256<float> beg = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef),
-                                                              Vector256.LoadUnsafe(ref yRef));
-                Vector256<float> end = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256<float>.Count)),
-                                                              Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256<float>.Count)));
+                Vector256<T> beg = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef),
+                                                          Vector256.LoadUnsafe(ref yRef));
+                Vector256<T> end = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)Vector256<T>.Count),
+                                                          Vector256.LoadUnsafe(ref yRef, remainder - (uint)Vector256<T>.Count));

-                if (remainder > (uint)(Vector256<float>.Count * 8))
+                if (remainder > (uint)(Vector256<T>.Count * 8))
                 {
                     // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful
                     // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted.

-                    fixed (float* px = &xRef)
-                    fixed (float* py = &yRef)
-                    fixed (float* pd = &dRef)
+                    fixed (T* px = &xRef)
+                    fixed (T* py = &yRef)
+                    fixed (T* pd = &dRef)
                     {
-                        float* xPtr = px;
-                        float* yPtr = py;
-                        float* dPtr = pd;
+                        T* xPtr = px;
+                        T* yPtr = py;
+                        T* dPtr = pd;

                         // We need to the ensure the underlying data can be aligned and only align
                         // it if it can. It is possible we have an unaligned ref, in which case we
                         // can never achieve the required SIMD alignment.

-                        bool canAlign = ((nuint)(dPtr) % sizeof(float)) == 0;
+                        bool canAlign = ((nuint)dPtr % (nuint)sizeof(T)) == 0;

                         if (canAlign)
                         {
@@ -4211,115 +4264,115 @@ static void Vectorized256(ref float xRef, ref float yRef, ref float dRef, nuint
                             // are more expensive than unaligned loads and aligning both is significantly more
                             // complex.

-                            nuint misalignment = ((uint)(sizeof(Vector256<float>)) - ((nuint)(dPtr) % (uint)(sizeof(Vector256<float>)))) / sizeof(float);
+                            nuint misalignment = ((uint)sizeof(Vector256<T>) - ((nuint)dPtr % (uint)sizeof(Vector256<T>))) / (uint)sizeof(T);

                             xPtr += misalignment;
                             yPtr += misalignment;
                             dPtr += misalignment;

-                            Debug.Assert(((nuint)(dPtr) % (uint)(sizeof(Vector256<float>))) == 0);
+                            Debug.Assert(((nuint)dPtr % (uint)sizeof(Vector256<T>)) == 0);

                             remainder -= misalignment;
                         }

-                        Vector256<float> vector1;
-                        Vector256<float> vector2;
-                        Vector256<float> vector3;
-                        Vector256<float> vector4;
+                        Vector256<T> vector1;
+                        Vector256<T> vector2;
+                        Vector256<T> vector3;
+                        Vector256<T> vector4;

-                        if ((remainder > (NonTemporalByteThreshold / sizeof(float))) && canAlign)
+                        if ((remainder > (NonTemporalByteThreshold / (nuint)sizeof(T))) && canAlign)
                         {
                             // This loop stores the data non-temporally, which benefits us when there
                             // is a large amount of data involved as it avoids polluting the cache.

-                            while (remainder >= (uint)(Vector256<float>.Count * 8))
+                            while (remainder >= (uint)(Vector256<T>.Count * 8))
                             {
                                 // We load, process, and store the first four vectors

-                                vector1 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256<float>.Count * 0)),
-                                                                 Vector256.Load(yPtr + (uint)(Vector256<float>.Count * 0)));
-                                vector2 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256<float>.Count * 1)),
-                                                                 Vector256.Load(yPtr + (uint)(Vector256<float>.Count * 1)));
-                                vector3 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256<float>.Count * 2)),
-                                                                 Vector256.Load(yPtr + (uint)(Vector256<float>.Count * 2)));
-                                vector4 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256<float>.Count * 3)),
-                                                                 Vector256.Load(yPtr + (uint)(Vector256<float>.Count * 3)));
+                                vector1 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256<T>.Count * 0)),
+                                                                 Vector256.Load(yPtr + (uint)(Vector256<T>.Count * 0)));
+                                vector2 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256<T>.Count * 1)),
+                                                                 Vector256.Load(yPtr + (uint)(Vector256<T>.Count * 1)));
+                                vector3 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256<T>.Count * 2)),
+                                                                 Vector256.Load(yPtr + (uint)(Vector256<T>.Count * 2)));
+                                vector4 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256<T>.Count * 3)),
+                                                                 Vector256.Load(yPtr + (uint)(Vector256<T>.Count * 3)));

-                                vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector256<float>.Count * 0));
-                                vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector256<float>.Count * 1));
-                                vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector256<float>.Count * 2));
-                                vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector256<float>.Count * 3));
+                                vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector256<T>.Count * 0));
+                                vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector256<T>.Count * 1));
+                                vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector256<T>.Count * 2));
+                                vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector256<T>.Count * 3));

                                 // We load, process, and store the next four vectors

-                                vector1 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256<float>.Count * 4)),
-                                                                 Vector256.Load(yPtr + (uint)(Vector256<float>.Count * 4)));
-                                vector2 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256<float>.Count * 5)),
-                                                                 Vector256.Load(yPtr + (uint)(Vector256<float>.Count * 5)));
-                                vector3 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256<float>.Count * 6)),
-                                                                 Vector256.Load(yPtr + (uint)(Vector256<float>.Count * 6)));
-                                vector4 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256<float>.Count * 7)),
-                                                                 Vector256.Load(yPtr + (uint)(Vector256<float>.Count * 7)));
+                                vector1 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256<T>.Count * 4)),
+                                                                 Vector256.Load(yPtr + (uint)(Vector256<T>.Count * 4)));
+                                vector2 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256<T>.Count * 5)),
+                                                                 Vector256.Load(yPtr + (uint)(Vector256<T>.Count * 5)));
+                                vector3 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256<T>.Count * 6)),
+                                                                 Vector256.Load(yPtr + (uint)(Vector256<T>.Count * 6)));
+                                vector4 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256<T>.Count * 7)),
+                                                                 Vector256.Load(yPtr + (uint)(Vector256<T>.Count * 7)));

-                                vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector256<float>.Count * 4));
-                                vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector256<float>.Count * 5));
-                                vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector256<float>.Count * 6));
-                                vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector256<float>.Count * 7));
+                                vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector256<T>.Count * 4));
+                                vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector256<T>.Count * 5));
+                                vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector256<T>.Count * 6));
+                                vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector256<T>.Count * 7));

                                 // We adjust the source and destination references, then update
                                 // the count of remaining elements to process.

-                                xPtr += (uint)(Vector256<float>.Count * 8);
-                                yPtr += (uint)(Vector256<float>.Count * 8);
-                                dPtr += (uint)(Vector256<float>.Count * 8);
+                                xPtr += (uint)(Vector256<T>.Count * 8);
+                                yPtr += (uint)(Vector256<T>.Count * 8);
+                                dPtr += (uint)(Vector256<T>.Count * 8);

-                                remainder -= (uint)(Vector256<float>.Count * 8);
+                                remainder -= (uint)(Vector256<T>.Count * 8);
                             }
                         }
                         else
                         {
-                            while (remainder >= (uint)(Vector256<float>.Count * 8))
+                            while (remainder >= (uint)(Vector256<T>.Count * 8))
                             {
                                 // We load, process, and store the first four vectors

-                                vector1 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256<float>.Count * 0)),
-                                                                 Vector256.Load(yPtr + (uint)(Vector256<float>.Count * 0)));
-                                vector2 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256<float>.Count * 1)),
-                                                                 Vector256.Load(yPtr + (uint)(Vector256<float>.Count * 1)));
-                                vector3 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256<float>.Count * 2)),
-                                                                 Vector256.Load(yPtr + (uint)(Vector256<float>.Count * 2)));
-                                vector4 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256<float>.Count * 3)),
-                                                                 Vector256.Load(yPtr + (uint)(Vector256<float>.Count * 3)));
+                                vector1 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256<T>.Count * 0)),
+                                                                 Vector256.Load(yPtr + (uint)(Vector256<T>.Count * 0)));
+                                vector2 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256<T>.Count * 1)),
+                                                                 Vector256.Load(yPtr + (uint)(Vector256<T>.Count * 1)));
+                                vector3 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256<T>.Count * 2)),
+                                                                 Vector256.Load(yPtr + (uint)(Vector256<T>.Count * 2)));
+                                vector4 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256<T>.Count * 3)),
+                                                                 Vector256.Load(yPtr + (uint)(Vector256<T>.Count * 3)));

-                                vector1.Store(dPtr + (uint)(Vector256<float>.Count * 0));
-                                vector2.Store(dPtr + (uint)(Vector256<float>.Count * 1));
-                                vector3.Store(dPtr + (uint)(Vector256<float>.Count * 2));
-                                vector4.Store(dPtr + (uint)(Vector256<float>.Count * 3));
+                                vector1.Store(dPtr + (uint)(Vector256<T>.Count * 0));
+                                vector2.Store(dPtr + (uint)(Vector256<T>.Count * 1));
+                                vector3.Store(dPtr + (uint)(Vector256<T>.Count * 2));
+                                vector4.Store(dPtr + (uint)(Vector256<T>.Count * 3));

                                 // We load, process, and store the next four vectors

-                                vector1 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256<float>.Count * 4)),
-                                                                 Vector256.Load(yPtr + (uint)(Vector256<float>.Count * 4)));
-                                vector2 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256<float>.Count * 5)),
-                                                                 Vector256.Load(yPtr + (uint)(Vector256<float>.Count * 5)));
-                                vector3 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256<float>.Count * 6)),
-                                                                 Vector256.Load(yPtr + (uint)(Vector256<float>.Count * 6)));
-                                vector4 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256<float>.Count * 7)),
-                                                                 Vector256.Load(yPtr + (uint)(Vector256<float>.Count * 7)));
+                                vector1 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256<T>.Count * 4)),
+                                                                 Vector256.Load(yPtr + (uint)(Vector256<T>.Count * 4)));
+                                vector2 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256<T>.Count * 5)),
+                                                                 Vector256.Load(yPtr + (uint)(Vector256<T>.Count * 5)));
+                                vector3 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256<T>.Count * 6)),
+                                                                 Vector256.Load(yPtr + (uint)(Vector256<T>.Count * 6)));
+                                vector4 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256<T>.Count * 7)),
+                                                                 Vector256.Load(yPtr + (uint)(Vector256<T>.Count * 7)));

-                                vector1.Store(dPtr + (uint)(Vector256<float>.Count * 4));
-                                vector2.Store(dPtr + (uint)(Vector256<float>.Count * 5));
-                                vector3.Store(dPtr + (uint)(Vector256<float>.Count * 6));
-                                vector4.Store(dPtr + (uint)(Vector256<float>.Count * 7));
+                                vector1.Store(dPtr + (uint)(Vector256<T>.Count * 4));
+                                vector2.Store(dPtr + (uint)(Vector256<T>.Count * 5));
+                                vector3.Store(dPtr + (uint)(Vector256<T>.Count * 6));
+                                vector4.Store(dPtr + (uint)(Vector256<T>.Count * 7));

                                 // We adjust the source and destination references, then update
                                 // the count of remaining elements to process.

-                                xPtr += (uint)(Vector256<float>.Count * 8);
-                                yPtr += (uint)(Vector256<float>.Count * 8);
-                                dPtr += (uint)(Vector256<float>.Count * 8);
+                                xPtr += (uint)(Vector256<T>.Count * 8);
+                                yPtr += (uint)(Vector256<T>.Count * 8);
+                                dPtr += (uint)(Vector256<T>.Count * 8);

-                                remainder -= (uint)(Vector256<float>.Count * 8);
+                                remainder -= (uint)(Vector256<T>.Count * 8);
                             }
                         }

@@ -4339,70 +4392,70 @@ static void Vectorized256(ref float xRef, ref float yRef, ref float dRef, nuint
                 // data before the first aligned address.

                 nuint endIndex = remainder;
-                remainder = (remainder + (uint)(Vector256<float>.Count - 1)) & (nuint)(-Vector256<float>.Count);
+                remainder = (remainder + (uint)(Vector256<T>.Count - 1)) & (nuint)(-Vector256<T>.Count);

-                switch (remainder / (uint)(Vector256<float>.Count))
+                switch (remainder / (uint)Vector256<T>.Count)
                 {
                     case 8:
                     {
-                        Vector256<float> vector = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256<float>.Count * 8)),
-                                                                         Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256<float>.Count * 8)));
-                        vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256<float>.Count * 8));
+                        Vector256<T> vector = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256<T>.Count * 8)),
+                                                                     Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256<T>.Count * 8)));
+                        vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256<T>.Count * 8));
                        goto case 7;
                     }

                     case 7:
                     {
-                        Vector256<float> vector = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256<float>.Count * 7)),
-                                                                         Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256<float>.Count * 7)));
-                        vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256<float>.Count * 7));
+                        Vector256<T> vector = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256<T>.Count * 7)),
+                                                                     Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256<T>.Count * 7)));
+                        vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256<T>.Count * 7));
                        goto case 6;
                     }

                     case 6:
                     {
-                        Vector256<float> vector = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256<float>.Count * 6)),
-                                                                         Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256<float>.Count * 6)));
-                        vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256<float>.Count * 6));
+                        Vector256<T> vector = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256<T>.Count * 6)),
+                                                                     Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256<T>.Count * 6)));
+                        vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256<T>.Count * 6));
                        goto case 5;
                     }

                     case 5:
                     {
-                        Vector256<float> vector = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256<float>.Count * 5)),
-                                                                         Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256<float>.Count * 5)));
-                        vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256<float>.Count * 5));
+                        Vector256<T> vector = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256<T>.Count * 5)),
+                                                                     Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256<T>.Count * 5)));
+                        vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256<T>.Count * 5));
                        goto case 4;
                     }

                     case 4:
                     {
-                        Vector256<float> vector = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256<float>.Count * 4)),
-                                                                         Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256<float>.Count * 4)));
-                        vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256<float>.Count * 4));
+                        Vector256<T> vector = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256<T>.Count * 4)),
+                                                                     Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256<T>.Count * 4)));
+                        vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256<T>.Count * 4));
                        goto case 3;
                     }

                     case 3:
                     {
-                        Vector256<float> vector = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256<float>.Count * 3)),
-                                                                         Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256<float>.Count * 3)));
-                        vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256<float>.Count * 3));
+                        Vector256<T> vector = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256<T>.Count * 3)),
+                                                                     Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256<T>.Count * 3)));
+                        vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256<T>.Count * 3));
                        goto case 2;
                     }

                     case 2:
                     {
-                        Vector256<float> vector = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256<float>.Count * 2)),
-                                                                         Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256<float>.Count * 2)));
-                        vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256<float>.Count * 2));
+                        Vector256<T> vector = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256<T>.Count * 2)),
+                                                                     Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256<T>.Count * 2)));
+                        vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256<T>.Count * 2));
                        goto case 1;
                     }

                     case 1:
                     {
                         // Store the last block, which includes any elements that wouldn't fill a full vector
-                        end.StoreUnsafe(ref dRef, endIndex - (uint)Vector256<float>.Count);
+                        end.StoreUnsafe(ref dRef, endIndex - (uint)Vector256<T>.Count);
                         goto case 0;
                     }

@@ -4416,7 +4469,7 @@ static void Vectorized256(ref float xRef, ref float yRef, ref float dRef, nuint
             }

             [MethodImpl(MethodImplOptions.AggressiveInlining)]
-            static void Vectorized256Small(ref float xRef, ref float yRef, ref float dRef, nuint remainder)
+            static void Vectorized256Small(ref T xRef, ref T yRef, ref T dRef, nuint remainder)
             {
                 switch (remainder)
                 {
@@ -4426,13 +4479,13 @@ static void Vectorized256Small(ref float xRef, ref float yRef, ref float dRef, n
                     {
                         Debug.Assert(Vector128.IsHardwareAccelerated);

-                        Vector128<float> beg = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef),
-                                                                      Vector128.LoadUnsafe(ref yRef));
-                        Vector128<float> end = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128<float>.Count)),
-                                                                      Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128<float>.Count)));
+                        Vector128<T> beg = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef),
+                                                                  Vector128.LoadUnsafe(ref yRef));
+                        Vector128<T> end = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)Vector128<T>.Count),
+                                                                  Vector128.LoadUnsafe(ref yRef, remainder - (uint)Vector128<T>.Count));

                         beg.StoreUnsafe(ref dRef);
-                        end.StoreUnsafe(ref dRef, remainder - (uint)(Vector128<float>.Count));
+                        end.StoreUnsafe(ref dRef, remainder - (uint)Vector128<T>.Count);

                         break;
                     }

@@ -4441,8 +4494,8 @@ static void Vectorized256Small(ref float xRef, ref float yRef, ref float dRef, n
                     {
                         Debug.Assert(Vector128.IsHardwareAccelerated);

-                        Vector128<float> beg = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef),
-                                                                      Vector128.LoadUnsafe(ref yRef));
+                        Vector128<T> beg = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef),
+                                                                  Vector128.LoadUnsafe(ref yRef));
                         beg.StoreUnsafe(ref dRef);

                         break;
@@ -4475,35 +4528,35 @@ static void Vectorized256Small(ref float xRef, ref float yRef, ref float dRef, n
                 }
             }
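Past NonTemporalByteThreshold the aligned loops switch to StoreAlignedNonTemporal: streaming stores write around the caches so a multi-megabyte destination does not evict warmer data, while below the threshold plain Store keeps results cache-resident for likely reuse. A hedged sketch of that decision (the threshold constant here is illustrative; the real value is an internal tuning detail, and the sketch copies only whole vectors):

using System.Runtime.Intrinsics;

static unsafe class NtDemo
{
    // Illustrative threshold; the BCL uses an internal NonTemporalByteThreshold.
    const nuint ThresholdBytes = 256 * 1024;

    static void Copy(float* src, float* dst, nuint count)
    {
        // Non-temporal stores require an aligned destination and pay off only
        // when the output is too large to benefit from staying in cache.
        bool nonTemporal = count * sizeof(float) > ThresholdBytes
                           && ((nuint)dst % (uint)sizeof(Vector256<float>)) == 0;

        for (nuint i = 0; i + (nuint)Vector256<float>.Count <= count; i += (nuint)Vector256<float>.Count)
        {
            Vector256<float> v = Vector256.Load(src + i);
            if (nonTemporal) v.StoreAlignedNonTemporal(dst + i); // bypasses cache
            else v.Store(dst + i);                               // stays in cache
        }
    }
}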
-            static void Vectorized512(ref float xRef, ref float yRef, ref float dRef, nuint remainder)
+            static void Vectorized512(ref T xRef, ref T yRef, ref T dRef, nuint remainder)
             {
-                ref float dRefBeg = ref dRef;
+                ref T dRefBeg = ref dRef;

                 // Preload the beginning and end so that overlapping accesses don't negatively impact the data

-                Vector512<float> beg = TBinaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef),
-                                                              Vector512.LoadUnsafe(ref yRef));
-                Vector512<float> end = TBinaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512<float>.Count)),
-                                                              Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512<float>.Count)));
+                Vector512<T> beg = TBinaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef),
+                                                          Vector512.LoadUnsafe(ref yRef));
+                Vector512<T> end = TBinaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)Vector512<T>.Count),
+                                                          Vector512.LoadUnsafe(ref yRef, remainder - (uint)Vector512<T>.Count));

-                if (remainder > (uint)(Vector512<float>.Count * 8))
+                if (remainder > (uint)(Vector512<T>.Count * 8))
                 {
                     // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful
                     // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted.

-                    fixed (float* px = &xRef)
-                    fixed (float* py = &yRef)
-                    fixed (float* pd = &dRef)
+                    fixed (T* px = &xRef)
+                    fixed (T* py = &yRef)
+                    fixed (T* pd = &dRef)
                     {
-                        float* xPtr = px;
-                        float* yPtr = py;
-                        float* dPtr = pd;
+                        T* xPtr = px;
+                        T* yPtr = py;
+                        T* dPtr = pd;

                         // We need to the ensure the underlying data can be aligned and only align
                         // it if it can. It is possible we have an unaligned ref, in which case we
                         // can never achieve the required SIMD alignment.

-                        bool canAlign = ((nuint)(dPtr) % sizeof(float)) == 0;
+                        bool canAlign = ((nuint)dPtr % (nuint)sizeof(T)) == 0;

                         if (canAlign)
                         {
@@ -4513,115 +4566,115 @@ static void Vectorized512(ref float xRef, ref float yRef, ref float dRef, nuint
                             // are more expensive than unaligned loads and aligning both is significantly more
                             // complex.

-                            nuint misalignment = ((uint)(sizeof(Vector512<float>)) - ((nuint)(dPtr) % (uint)(sizeof(Vector512<float>)))) / sizeof(float);
+                            nuint misalignment = ((uint)sizeof(Vector512<T>) - ((nuint)dPtr % (uint)sizeof(Vector512<T>))) / (uint)sizeof(T);

                             xPtr += misalignment;
                             yPtr += misalignment;
                             dPtr += misalignment;

-                            Debug.Assert(((nuint)(dPtr) % (uint)(sizeof(Vector512<float>))) == 0);
+                            Debug.Assert(((nuint)dPtr % (uint)sizeof(Vector512<T>)) == 0);

                             remainder -= misalignment;
                         }

-                        Vector512<float> vector1;
-                        Vector512<float> vector2;
-                        Vector512<float> vector3;
-                        Vector512<float> vector4;
+                        Vector512<T> vector1;
+                        Vector512<T> vector2;
+                        Vector512<T> vector3;
+                        Vector512<T> vector4;

-                        if ((remainder > (NonTemporalByteThreshold / sizeof(float))) && canAlign)
+                        if ((remainder > (NonTemporalByteThreshold / (nuint)sizeof(T))) && canAlign)
                         {
                             // This loop stores the data non-temporally, which benefits us when there
                             // is a large amount of data involved as it avoids polluting the cache.

-                            while (remainder >= (uint)(Vector512<float>.Count * 8))
+                            while (remainder >= (uint)(Vector512<T>.Count * 8))
                             {
                                 // We load, process, and store the first four vectors

-                                vector1 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512<float>.Count * 0)),
-                                                                 Vector512.Load(yPtr + (uint)(Vector512<float>.Count * 0)));
-                                vector2 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512<float>.Count * 1)),
-                                                                 Vector512.Load(yPtr + (uint)(Vector512<float>.Count * 1)));
-                                vector3 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512<float>.Count * 2)),
-                                                                 Vector512.Load(yPtr + (uint)(Vector512<float>.Count * 2)));
-                                vector4 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512<float>.Count * 3)),
-                                                                 Vector512.Load(yPtr + (uint)(Vector512<float>.Count * 3)));
+                                vector1 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512<T>.Count * 0)),
+                                                                 Vector512.Load(yPtr + (uint)(Vector512<T>.Count * 0)));
+                                vector2 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512<T>.Count * 1)),
+                                                                 Vector512.Load(yPtr + (uint)(Vector512<T>.Count * 1)));
+                                vector3 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512<T>.Count * 2)),
+                                                                 Vector512.Load(yPtr + (uint)(Vector512<T>.Count * 2)));
+                                vector4 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512<T>.Count * 3)),
+                                                                 Vector512.Load(yPtr + (uint)(Vector512<T>.Count * 3)));

-                                vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector512<float>.Count * 0));
-                                vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector512<float>.Count * 1));
-                                vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector512<float>.Count * 2));
-                                vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector512<float>.Count * 3));
+                                vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector512<T>.Count * 0));
+                                vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector512<T>.Count * 1));
+                                vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector512<T>.Count * 2));
+                                vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector512<T>.Count * 3));

                                 // We load, process, and store the next four vectors

-                                vector1 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512<float>.Count * 4)),
-                                                                 Vector512.Load(yPtr + (uint)(Vector512<float>.Count * 4)));
-                                vector2 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512<float>.Count * 5)),
-                                                                 Vector512.Load(yPtr + (uint)(Vector512<float>.Count * 5)));
-                                vector3 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512<float>.Count * 6)),
-                                                                 Vector512.Load(yPtr + (uint)(Vector512<float>.Count * 6)));
-                                vector4 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512<float>.Count * 7)),
-                                                                 Vector512.Load(yPtr + (uint)(Vector512<float>.Count * 7)));
+                                vector1 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512<T>.Count * 4)),
+                                                                 Vector512.Load(yPtr + (uint)(Vector512<T>.Count * 4)));
+                                vector2 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512<T>.Count * 5)),
+                                                                 Vector512.Load(yPtr + (uint)(Vector512<T>.Count * 5)));
+                                vector3 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512<T>.Count * 6)),
+                                                                 Vector512.Load(yPtr + (uint)(Vector512<T>.Count * 6)));
+                                vector4 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512<T>.Count * 7)),
+                                                                 Vector512.Load(yPtr + (uint)(Vector512<T>.Count * 7)));

-                                vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector512<float>.Count * 4));
-                                vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector512<float>.Count * 5));
-                                vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector512<float>.Count * 6));
-                                vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector512<float>.Count * 7));
+                                vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector512<T>.Count * 4));
+                                vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector512<T>.Count * 5));
+                                vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector512<T>.Count * 6));
+                                vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector512<T>.Count * 7));

                                 // We adjust the source and destination references, then update
                                 // the count of remaining elements to process.

-                                xPtr += (uint)(Vector512<float>.Count * 8);
-                                yPtr += (uint)(Vector512<float>.Count * 8);
-                                dPtr += (uint)(Vector512<float>.Count * 8);
+                                xPtr += (uint)(Vector512<T>.Count * 8);
+                                yPtr += (uint)(Vector512<T>.Count * 8);
+                                dPtr += (uint)(Vector512<T>.Count * 8);

-                                remainder -= (uint)(Vector512<float>.Count * 8);
+                                remainder -= (uint)(Vector512<T>.Count * 8);
                             }
                         }
                         else
                         {
-                            while (remainder >= (uint)(Vector512<float>.Count * 8))
+                            while (remainder >= (uint)(Vector512<T>.Count * 8))
                             {
                                 // We load, process, and store the first four vectors

-                                vector1 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512<float>.Count * 0)),
-                                                                 Vector512.Load(yPtr + (uint)(Vector512<float>.Count * 0)));
-                                vector2 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512<float>.Count * 1)),
-                                                                 Vector512.Load(yPtr + (uint)(Vector512<float>.Count * 1)));
-                                vector3 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512<float>.Count * 2)),
-                                                                 Vector512.Load(yPtr + (uint)(Vector512<float>.Count * 2)));
-                                vector4 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512<float>.Count * 3)),
-                                                                 Vector512.Load(yPtr + (uint)(Vector512<float>.Count * 3)));
+                                vector1 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512<T>.Count * 0)),
+                                                                 Vector512.Load(yPtr + (uint)(Vector512<T>.Count * 0)));
+                                vector2 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512<T>.Count * 1)),
+                                                                 Vector512.Load(yPtr + (uint)(Vector512<T>.Count * 1)));
+                                vector3 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512<T>.Count * 2)),
+                                                                 Vector512.Load(yPtr + (uint)(Vector512<T>.Count * 2)));
+                                vector4 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512<T>.Count * 3)),
+                                                                 Vector512.Load(yPtr + (uint)(Vector512<T>.Count * 3)));

-                                vector1.Store(dPtr + (uint)(Vector512<float>.Count * 0));
-                                vector2.Store(dPtr + (uint)(Vector512<float>.Count * 1));
-                                vector3.Store(dPtr + (uint)(Vector512<float>.Count * 2));
-                                vector4.Store(dPtr + (uint)(Vector512<float>.Count * 3));
+                                vector1.Store(dPtr + (uint)(Vector512<T>.Count * 0));
+                                vector2.Store(dPtr + (uint)(Vector512<T>.Count * 1));
+                                vector3.Store(dPtr + (uint)(Vector512<T>.Count * 2));
+                                vector4.Store(dPtr + (uint)(Vector512<T>.Count * 3));

                                 // We load, process, and store the next four vectors

-                                vector1 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512<float>.Count * 4)),
-                                                                 Vector512.Load(yPtr + (uint)(Vector512<float>.Count * 4)));
-                                vector2 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512<float>.Count * 5)),
-                                                                 Vector512.Load(yPtr + (uint)(Vector512<float>.Count * 5)));
-                                vector3 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512<float>.Count * 6)),
-                                                                 Vector512.Load(yPtr + (uint)(Vector512<float>.Count * 6)));
-                                vector4 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512<float>.Count * 7)),
-                                                                 Vector512.Load(yPtr + (uint)(Vector512<float>.Count * 7)));
+                                vector1 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512<T>.Count * 4)),
+                                                                 Vector512.Load(yPtr + (uint)(Vector512<T>.Count * 4)));
+                                vector2 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512<T>.Count * 5)),
+                                                                 Vector512.Load(yPtr + (uint)(Vector512<T>.Count * 5)));
+                                vector3 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512<T>.Count * 6)),
+                                                                 Vector512.Load(yPtr + (uint)(Vector512<T>.Count * 6)));
+                                vector4 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512<T>.Count * 7)),
+                                                                 Vector512.Load(yPtr + (uint)(Vector512<T>.Count * 7)));

-                                vector1.Store(dPtr + (uint)(Vector512<float>.Count * 4));
-                                vector2.Store(dPtr + (uint)(Vector512<float>.Count * 5));
-                                vector3.Store(dPtr + (uint)(Vector512<float>.Count * 6));
-                                vector4.Store(dPtr + (uint)(Vector512<float>.Count * 7));
+                                vector1.Store(dPtr + (uint)(Vector512<T>.Count * 4));
+                                vector2.Store(dPtr + (uint)(Vector512<T>.Count * 5));
+                                vector3.Store(dPtr + (uint)(Vector512<T>.Count * 6));
+                                vector4.Store(dPtr + (uint)(Vector512<T>.Count * 7));

                                 // We adjust the source and destination references, then update
                                 // the count of remaining elements to process.

-                                xPtr += (uint)(Vector512<float>.Count * 8);
-                                yPtr += (uint)(Vector512<float>.Count * 8);
-                                dPtr += (uint)(Vector512<float>.Count * 8);
+                                xPtr += (uint)(Vector512<T>.Count * 8);
+                                yPtr += (uint)(Vector512<T>.Count * 8);
+                                dPtr += (uint)(Vector512<T>.Count * 8);

-                                remainder -= (uint)(Vector512<float>.Count * 8);
+                                remainder -= (uint)(Vector512<T>.Count * 8);
                             }
                         }

@@ -4641,70 +4694,70 @@ static void Vectorized512(ref float xRef, ref float yRef, ref float dRef, nuint
                 // data before the first aligned address.

                 nuint endIndex = remainder;
-                remainder = (remainder + (uint)(Vector512<float>.Count - 1)) & (nuint)(-Vector512<float>.Count);
+                remainder = (remainder + (uint)(Vector512<T>.Count - 1)) & (nuint)(-Vector512<T>.Count);

-                switch (remainder / (uint)(Vector512<float>.Count))
+                switch (remainder / (uint)Vector512<T>.Count)
                 {
                     case 8:
                     {
-                        Vector512<float> vector = TBinaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512<float>.Count * 8)),
-                                                                         Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512<float>.Count * 8)));
-                        vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512<float>.Count * 8));
+                        Vector512<T> vector = TBinaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512<T>.Count * 8)),
+                                                                     Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512<T>.Count * 8)));
+                        vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512<T>.Count * 8));
                        goto case 7;
                     }

                     case 7:
                     {
-                        Vector512<float> vector = TBinaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512<float>.Count * 7)),
-                                                                         Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512<float>.Count * 7)));
-                        vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512<float>.Count * 7));
+                        Vector512<T> vector = TBinaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512<T>.Count * 7)),
+                                                                     Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512<T>.Count * 7)));
+                        vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512<T>.Count * 7));
                        goto case 6;
                     }

                     case 6:
                     {
-                        Vector512<float> vector = TBinaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512<float>.Count * 6)),
-                                                                         Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512<float>.Count * 6)));
-                        vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512<float>.Count * 6));
+                        Vector512<T> vector = TBinaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512<T>.Count * 6)),
+                                                                     Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512<T>.Count * 6)));
+                        vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512<T>.Count * 6));
                        goto case 5;
                     }

                     case 5:
                     {
-                        Vector512<float> vector = TBinaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512<float>.Count * 5)),
-                                                                         Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512<float>.Count * 5)));
-                        vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512<float>.Count * 5));
+                        Vector512<T> vector = TBinaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512<T>.Count * 5)),
+                                                                     Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512<T>.Count * 5)));
+                        vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512<T>.Count * 5));
                        goto case 4;
                     }

                     case 4:
                     {
-                        Vector512<float> vector = TBinaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512<float>.Count * 4)),
-                                                                         Vector512.LoadUnsafe(ref
yRef, remainder - (uint)(Vector512.Count * 4))); - vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 4)); + Vector512 vector = TBinaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 4)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 4))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 4)); goto case 3; } case 3: { - Vector512 vector = TBinaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 3)), - Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 3))); - vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 3)); + Vector512 vector = TBinaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 3)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 3))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 3)); goto case 2; } case 2: { - Vector512 vector = TBinaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 2)), - Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 2))); - vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 2)); + Vector512 vector = TBinaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 2)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 2))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 2)); goto case 1; } case 1: { // Store the last block, which includes any elements that wouldn't fill a full vector - end.StoreUnsafe(ref dRef, endIndex - (uint)Vector512.Count); + end.StoreUnsafe(ref dRef, endIndex - (uint)Vector512.Count); goto case 0; } @@ -4718,7 +4771,7 @@ static void Vectorized512(ref float xRef, ref float yRef, ref float dRef, nuint } [MethodImpl(MethodImplOptions.AggressiveInlining)] - static void Vectorized512Small(ref float xRef, ref float yRef, ref float dRef, nuint remainder) + static void Vectorized512Small(ref T xRef, ref T yRef, ref T dRef, nuint remainder) { switch (remainder) { @@ -4732,13 +4785,13 @@ static void Vectorized512Small(ref float xRef, ref float yRef, ref float dRef, n { Debug.Assert(Vector256.IsHardwareAccelerated); - Vector256 beg = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef), - Vector256.LoadUnsafe(ref yRef)); - Vector256 end = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count)), - Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count))); + Vector256 beg = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef), + Vector256.LoadUnsafe(ref yRef)); + Vector256 end = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)Vector256.Count), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)Vector256.Count)); beg.StoreUnsafe(ref dRef); - end.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count)); + end.StoreUnsafe(ref dRef, remainder - (uint)Vector256.Count); break; } @@ -4747,8 +4800,8 @@ static void Vectorized512Small(ref float xRef, ref float yRef, ref float dRef, n { Debug.Assert(Vector256.IsHardwareAccelerated); - Vector256 beg = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef), - Vector256.LoadUnsafe(ref yRef)); + Vector256 beg = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef), + Vector256.LoadUnsafe(ref yRef)); beg.StoreUnsafe(ref dRef); break; @@ -4760,13 +4813,13 @@ static void Vectorized512Small(ref float xRef, ref float yRef, ref float dRef, n { 
Debug.Assert(Vector128.IsHardwareAccelerated); - Vector128 beg = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef), - Vector128.LoadUnsafe(ref yRef)); - Vector128 end = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count)), - Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count))); + Vector128 beg = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef), + Vector128.LoadUnsafe(ref yRef)); + Vector128 end = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)Vector128.Count), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)Vector128.Count)); beg.StoreUnsafe(ref dRef); - end.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count)); + end.StoreUnsafe(ref dRef, remainder - (uint)Vector128.Count); break; } @@ -4775,8 +4828,8 @@ static void Vectorized512Small(ref float xRef, ref float yRef, ref float dRef, n { Debug.Assert(Vector128.IsHardwareAccelerated); - Vector128 beg = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef), - Vector128.LoadUnsafe(ref yRef)); + Vector128 beg = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef), + Vector128.LoadUnsafe(ref yRef)); beg.StoreUnsafe(ref dRef); break; @@ -4814,18 +4867,20 @@ static void Vectorized512Small(ref float xRef, ref float yRef, ref float dRef, n /// Performs an element-wise operation on and , /// and writes the results to . /// + /// The element type. /// /// Specifies the operation to perform on each element loaded from with . /// - private static void InvokeSpanScalarIntoSpan( - ReadOnlySpan x, float y, Span destination) - where TBinaryOperator : struct, IBinaryOperator => - InvokeSpanScalarIntoSpan(x, y, destination); + private static void InvokeSpanScalarIntoSpan( + ReadOnlySpan x, T y, Span destination) + where TBinaryOperator : struct, IBinaryOperator => + InvokeSpanScalarIntoSpan, TBinaryOperator>(x, y, destination); /// /// Performs an element-wise operation on and , /// and writes the results to . /// + /// The element type. /// /// Specifies the operation to perform on each element loaded from . /// It is not used with . @@ -4833,10 +4888,10 @@ private static void InvokeSpanScalarIntoSpan( /// /// Specifies the operation to perform on the transformed value from with . 
/// - private static void InvokeSpanScalarIntoSpan( - ReadOnlySpan x, float y, Span destination) - where TTransformOperator : struct, IUnaryOperator - where TBinaryOperator : struct, IBinaryOperator + private static void InvokeSpanScalarIntoSpan( + ReadOnlySpan x, T y, Span destination) + where TTransformOperator : struct, IUnaryOperator + where TBinaryOperator : struct, IBinaryOperator { if (x.Length > destination.Length) { @@ -4850,14 +4905,14 @@ private static void InvokeSpanScalarIntoSpan.IsSupported && TTransformOperator.Vectorizable && Unsafe.SizeOf() >= 4) { - if (remainder >= (uint)(Vector512.Count)) + if (remainder >= (uint)Vector512.Count) { Vectorized512(ref xRef, y, ref dRef, remainder); } @@ -4873,9 +4928,9 @@ private static void InvokeSpanScalarIntoSpan.IsSupported && TTransformOperator.Vectorizable && Unsafe.SizeOf() >= 4) { - if (remainder >= (uint)(Vector256.Count)) + if (remainder >= (uint)Vector256.Count) { Vectorized256(ref xRef, y, ref dRef, remainder); } @@ -4891,9 +4946,9 @@ private static void InvokeSpanScalarIntoSpan.IsSupported && TTransformOperator.Vectorizable && Unsafe.SizeOf() >= 4) { - if (remainder >= (uint)(Vector128.Count)) + if (remainder >= (uint)Vector128.Count) { Vectorized128(ref xRef, y, ref dRef, remainder); } @@ -4915,7 +4970,7 @@ private static void InvokeSpanScalarIntoSpan yVec = Vector128.Create(y); + Vector128 yVec = Vector128.Create(y); - Vector128 beg = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef)), - yVec); - Vector128 end = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count))), - yVec); + Vector128 beg = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef)), + yVec); + Vector128 end = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)Vector128.Count)), + yVec); - if (remainder > (uint)(Vector128.Count * 8)) + if (remainder > (uint)(Vector128.Count * 8)) { // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted. - fixed (float* px = &xRef) - fixed (float* pd = &dRef) + fixed (T* px = &xRef) + fixed (T* pd = &dRef) { - float* xPtr = px; - float* dPtr = pd; + T* xPtr = px; + T* dPtr = pd; // We need to the ensure the underlying data can be aligned and only align // it if it can. It is possible we have an unaligned ref, in which case we // can never achieve the required SIMD alignment. - bool canAlign = ((nuint)(dPtr) % sizeof(float)) == 0; + bool canAlign = ((nuint)dPtr % (nuint)sizeof(T)) == 0; if (canAlign) { @@ -4962,112 +5017,112 @@ static void Vectorized128(ref float xRef, float y, ref float dRef, nuint remaind // are more expensive than unaligned loads and aligning both is significantly more // complex. 
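The computation that follows this comment converts a byte-level misalignment of the destination pointer into an element count, which is why the generic version divides by sizeof(T) where the float version divided by sizeof(float). In isolation, and assuming T : unmanaged (implied by the pointer use in the diff), the arithmetic looks like this; the method name is mine:

using System.Diagnostics;
using System.Runtime.Intrinsics;

internal static unsafe class MisalignmentSketch
{
    // How many T elements to advance so dPtr lands on a Vector128 boundary.
    // When dPtr is already vector-aligned this returns a full Vector128<T>.Count,
    // which is safe because the preloaded "beg" vector already covers the first block.
    public static nuint ElementsToVector128Boundary<T>(T* dPtr) where T : unmanaged
    {
        nuint misalignment = ((uint)sizeof(Vector128<T>) - ((nuint)dPtr % (uint)sizeof(Vector128<T>))) / (uint)sizeof(T);
        Debug.Assert(((nuint)(dPtr + misalignment) % (uint)sizeof(Vector128<T>)) == 0,
                     "Requires dPtr to be at least element-aligned (the canAlign check).");
        return misalignment;
    }
}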
- nuint misalignment = ((uint)(sizeof(Vector128)) - ((nuint)(dPtr) % (uint)(sizeof(Vector128)))) / sizeof(float); + nuint misalignment = ((uint)sizeof(Vector128) - ((nuint)dPtr % (uint)sizeof(Vector128))) / (uint)sizeof(T); xPtr += misalignment; dPtr += misalignment; - Debug.Assert(((nuint)(dPtr) % (uint)(sizeof(Vector128))) == 0); + Debug.Assert(((nuint)dPtr % (uint)sizeof(Vector128)) == 0); remainder -= misalignment; } - Vector128 vector1; - Vector128 vector2; - Vector128 vector3; - Vector128 vector4; + Vector128 vector1; + Vector128 vector2; + Vector128 vector3; + Vector128 vector4; - if ((remainder > (NonTemporalByteThreshold / sizeof(float))) && canAlign) + if ((remainder > (NonTemporalByteThreshold / (nuint)sizeof(T))) && canAlign) { // This loop stores the data non-temporally, which benefits us when there // is a large amount of data involved as it avoids polluting the cache. - while (remainder >= (uint)(Vector128.Count * 8)) + while (remainder >= (uint)(Vector128.Count * 8)) { // We load, process, and store the first four vectors - vector1 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 0))), + vector1 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 0))), yVec); - vector2 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 1))), + vector2 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 1))), yVec); - vector3 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 2))), + vector3 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 2))), yVec); - vector4 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 3))), + vector4 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 3))), yVec); - vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 0)); - vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 1)); - vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 2)); - vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 3)); + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 0)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 1)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 2)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 3)); // We load, process, and store the next four vectors - vector1 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 4))), + vector1 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 4))), yVec); - vector2 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 5))), + vector2 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 5))), yVec); - vector3 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 6))), + vector3 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 6))), yVec); - vector4 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 7))), + vector4 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 7))), yVec); 
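Worth pausing on the shape repeated throughout this loop: TBinaryOperator.Invoke(TTransformOperator.Invoke(load), yVec). Each element loaded from x flows through a unary transform before the binary operator folds in the broadcast scalar; the single-operator overload earlier simply plugs in an identity transform. Simplified stand-ins for the internal operator abstractions, with shapes inferred from the where-clauses in this diff (the real interfaces carry more vector-width overloads):

using System.Numerics;
using System.Runtime.Intrinsics;

internal interface IUnaryOperator<T>
{
    static abstract bool Vectorizable { get; } // referenced as TTransformOperator.Vectorizable above
    static abstract T Invoke(T x);
    static abstract Vector128<T> Invoke(Vector128<T> x);
}

internal interface IBinaryOperator<T>
{
    static abstract T Invoke(T x, T y);
    static abstract Vector128<T> Invoke(Vector128<T> x, Vector128<T> y);
}

// Pass-through transform used when no per-element transform is wanted.
internal readonly struct IdentityOperator<T> : IUnaryOperator<T>
{
    public static bool Vectorizable => true;
    public static T Invoke(T x) => x;
    public static Vector128<T> Invoke(Vector128<T> x) => x;
}

// Example binary operator over any additive T.
internal readonly struct AddOperator<T> : IBinaryOperator<T>
    where T : IAdditionOperators<T, T, T>
{
    public static T Invoke(T x, T y) => x + y;
    public static Vector128<T> Invoke(Vector128<T> x, Vector128<T> y) => x + y;
}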
- vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 4)); - vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 5)); - vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 6)); - vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 7)); + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 4)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 5)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 6)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 7)); // We adjust the source and destination references, then update // the count of remaining elements to process. - xPtr += (uint)(Vector128.Count * 8); - dPtr += (uint)(Vector128.Count * 8); + xPtr += (uint)(Vector128.Count * 8); + dPtr += (uint)(Vector128.Count * 8); - remainder -= (uint)(Vector128.Count * 8); + remainder -= (uint)(Vector128.Count * 8); } } else { - while (remainder >= (uint)(Vector128.Count * 8)) + while (remainder >= (uint)(Vector128.Count * 8)) { // We load, process, and store the first four vectors - vector1 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 0))), + vector1 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 0))), yVec); - vector2 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 1))), + vector2 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 1))), yVec); - vector3 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 2))), + vector3 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 2))), yVec); - vector4 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 3))), + vector4 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 3))), yVec); - vector1.Store(dPtr + (uint)(Vector128.Count * 0)); - vector2.Store(dPtr + (uint)(Vector128.Count * 1)); - vector3.Store(dPtr + (uint)(Vector128.Count * 2)); - vector4.Store(dPtr + (uint)(Vector128.Count * 3)); + vector1.Store(dPtr + (uint)(Vector128.Count * 0)); + vector2.Store(dPtr + (uint)(Vector128.Count * 1)); + vector3.Store(dPtr + (uint)(Vector128.Count * 2)); + vector4.Store(dPtr + (uint)(Vector128.Count * 3)); // We load, process, and store the next four vectors - vector1 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 4))), + vector1 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 4))), yVec); - vector2 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 5))), + vector2 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 5))), yVec); - vector3 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 6))), + vector3 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 6))), yVec); - vector4 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 7))), + vector4 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 7))), yVec); - vector1.Store(dPtr + (uint)(Vector128.Count * 4)); - vector2.Store(dPtr + (uint)(Vector128.Count * 5)); 
- vector3.Store(dPtr + (uint)(Vector128.Count * 6)); - vector4.Store(dPtr + (uint)(Vector128.Count * 7)); + vector1.Store(dPtr + (uint)(Vector128.Count * 4)); + vector2.Store(dPtr + (uint)(Vector128.Count * 5)); + vector3.Store(dPtr + (uint)(Vector128.Count * 6)); + vector4.Store(dPtr + (uint)(Vector128.Count * 7)); // We adjust the source and destination references, then update // the count of remaining elements to process. - xPtr += (uint)(Vector128.Count * 8); - dPtr += (uint)(Vector128.Count * 8); + xPtr += (uint)(Vector128.Count * 8); + dPtr += (uint)(Vector128.Count * 8); - remainder -= (uint)(Vector128.Count * 8); + remainder -= (uint)(Vector128.Count * 8); } } @@ -5086,70 +5141,70 @@ static void Vectorized128(ref float xRef, float y, ref float dRef, nuint remaind // data before the first aligned address. nuint endIndex = remainder; - remainder = (remainder + (uint)(Vector128.Count - 1)) & (nuint)(-Vector128.Count); + remainder = (remainder + (uint)(Vector128.Count - 1)) & (nuint)(-Vector128.Count); - switch (remainder / (uint)(Vector128.Count)) + switch (remainder / (uint)Vector128.Count) { case 8: { - Vector128 vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 8))), - yVec); - vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 8)); + Vector128 vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 8))), + yVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 8)); goto case 7; } case 7: { - Vector128 vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 7))), - yVec); - vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 7)); + Vector128 vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 7))), + yVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 7)); goto case 6; } case 6: { - Vector128 vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 6))), - yVec); - vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 6)); + Vector128 vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 6))), + yVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 6)); goto case 5; } case 5: { - Vector128 vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 5))), - yVec); - vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 5)); + Vector128 vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 5))), + yVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 5)); goto case 4; } case 4: { - Vector128 vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 4))), - yVec); - vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 4)); + Vector128 vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 4))), + yVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 4)); goto case 3; } case 3: { - Vector128 vector = 
TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 3))), - yVec); - vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 3)); + Vector128 vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 3))), + yVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 3)); goto case 2; } case 2: { - Vector128 vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 2))), - yVec); - vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 2)); + Vector128 vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 2))), + yVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 2)); goto case 1; } case 1: { // Store the last block, which includes any elements that wouldn't fill a full vector - end.StoreUnsafe(ref dRef, endIndex - (uint)Vector128.Count); + end.StoreUnsafe(ref dRef, endIndex - (uint)Vector128.Count); goto case 0; } @@ -5163,7 +5218,7 @@ static void Vectorized128(ref float xRef, float y, ref float dRef, nuint remaind } [MethodImpl(MethodImplOptions.AggressiveInlining)] - static void Vectorized128Small(ref float xRef, float y, ref float dRef, nuint remainder) + static void Vectorized128Small(ref T xRef, T y, ref T dRef, nuint remainder) { switch (remainder) { @@ -5194,35 +5249,35 @@ static void Vectorized128Small(ref float xRef, float y, ref float dRef, nuint re } } - static void Vectorized256(ref float xRef, float y, ref float dRef, nuint remainder) + static void Vectorized256(ref T xRef, T y, ref T dRef, nuint remainder) { - ref float dRefBeg = ref dRef; + ref T dRefBeg = ref dRef; // Preload the beginning and end so that overlapping accesses don't negatively impact the data - Vector256 yVec = Vector256.Create(y); + Vector256 yVec = Vector256.Create(y); - Vector256 beg = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef)), - yVec); - Vector256 end = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count))), - yVec); + Vector256 beg = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef)), + yVec); + Vector256 end = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)Vector256.Count)), + yVec); - if (remainder > (uint)(Vector256.Count * 8)) + if (remainder > (uint)(Vector256.Count * 8)) { // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted. - fixed (float* px = &xRef) - fixed (float* pd = &dRef) + fixed (T* px = &xRef) + fixed (T* pd = &dRef) { - float* xPtr = px; - float* dPtr = pd; + T* xPtr = px; + T* dPtr = pd; // We need to the ensure the underlying data can be aligned and only align // it if it can. It is possible we have an unaligned ref, in which case we // can never achieve the required SIMD alignment. - bool canAlign = ((nuint)(dPtr) % sizeof(float)) == 0; + bool canAlign = ((nuint)dPtr % (nuint)sizeof(T)) == 0; if (canAlign) { @@ -5232,112 +5287,112 @@ static void Vectorized256(ref float xRef, float y, ref float dRef, nuint remaind // are more expensive than unaligned loads and aligning both is significantly more // complex. 
- nuint misalignment = ((uint)(sizeof(Vector256)) - ((nuint)(dPtr) % (uint)(sizeof(Vector256)))) / sizeof(float); + nuint misalignment = ((uint)sizeof(Vector256) - ((nuint)dPtr % (uint)sizeof(Vector256))) / (uint)sizeof(T); xPtr += misalignment; dPtr += misalignment; - Debug.Assert(((nuint)(dPtr) % (uint)(sizeof(Vector256))) == 0); + Debug.Assert(((nuint)dPtr % (uint)sizeof(Vector256)) == 0); remainder -= misalignment; } - Vector256 vector1; - Vector256 vector2; - Vector256 vector3; - Vector256 vector4; + Vector256 vector1; + Vector256 vector2; + Vector256 vector3; + Vector256 vector4; - if ((remainder > (NonTemporalByteThreshold / sizeof(float))) && canAlign) + if ((remainder > (NonTemporalByteThreshold / (nuint)sizeof(T))) && canAlign) { // This loop stores the data non-temporally, which benefits us when there // is a large amount of data involved as it avoids polluting the cache. - while (remainder >= (uint)(Vector256.Count * 8)) + while (remainder >= (uint)(Vector256.Count * 8)) { // We load, process, and store the first four vectors - vector1 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 0))), + vector1 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 0))), yVec); - vector2 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 1))), + vector2 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 1))), yVec); - vector3 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 2))), + vector3 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 2))), yVec); - vector4 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 3))), + vector4 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 3))), yVec); - vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 0)); - vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 1)); - vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 2)); - vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 3)); + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 0)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 1)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 2)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 3)); // We load, process, and store the next four vectors - vector1 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 4))), + vector1 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 4))), yVec); - vector2 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 5))), + vector2 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 5))), yVec); - vector3 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 6))), + vector3 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 6))), yVec); - vector4 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 7))), + vector4 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 7))), yVec); 
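Stripped of the alignment, unrolling, and non-temporal machinery, the span-scalar paths in this region boil down to: broadcast y once, process whole vectors, and finish the tail. A deliberately simplified sketch under that reading (the real code replaces the scalar tail with the overlapping-vector switch, and additionally gates vectorization on Unsafe.SizeOf<T>() >= 4):

using System;
using System.Numerics;
using System.Runtime.InteropServices;
using System.Runtime.Intrinsics;

internal static class SpanScalarSketch
{
    public static void AddScalar<T>(ReadOnlySpan<T> x, T y, Span<T> destination)
        where T : IAdditionOperators<T, T, T>
    {
        if (x.Length > destination.Length)
        {
            throw new ArgumentException("Destination is too short.", nameof(destination));
        }

        ref T xRef = ref MemoryMarshal.GetReference(x);
        ref T dRef = ref MemoryMarshal.GetReference(destination);
        int i = 0;

        if (Vector128.IsHardwareAccelerated && Vector128<T>.IsSupported)
        {
            Vector128<T> yVec = Vector128.Create(y); // broadcast once, reused every iteration

            for (; i <= x.Length - Vector128<T>.Count; i += Vector128<T>.Count)
            {
                (Vector128.LoadUnsafe(ref xRef, (nuint)i) + yVec).StoreUnsafe(ref dRef, (nuint)i);
            }
        }

        for (; i < x.Length; i++) // scalar tail
        {
            destination[i] = x[i] + y;
        }
    }
}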
- vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 4)); - vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 5)); - vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 6)); - vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 7)); + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 4)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 5)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 6)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 7)); // We adjust the source and destination references, then update // the count of remaining elements to process. - xPtr += (uint)(Vector256.Count * 8); - dPtr += (uint)(Vector256.Count * 8); + xPtr += (uint)(Vector256.Count * 8); + dPtr += (uint)(Vector256.Count * 8); - remainder -= (uint)(Vector256.Count * 8); + remainder -= (uint)(Vector256.Count * 8); } } else { - while (remainder >= (uint)(Vector256.Count * 8)) + while (remainder >= (uint)(Vector256.Count * 8)) { // We load, process, and store the first four vectors - vector1 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 0))), + vector1 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 0))), yVec); - vector2 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 1))), + vector2 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 1))), yVec); - vector3 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 2))), + vector3 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 2))), yVec); - vector4 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 3))), + vector4 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 3))), yVec); - vector1.Store(dPtr + (uint)(Vector256.Count * 0)); - vector2.Store(dPtr + (uint)(Vector256.Count * 1)); - vector3.Store(dPtr + (uint)(Vector256.Count * 2)); - vector4.Store(dPtr + (uint)(Vector256.Count * 3)); + vector1.Store(dPtr + (uint)(Vector256.Count * 0)); + vector2.Store(dPtr + (uint)(Vector256.Count * 1)); + vector3.Store(dPtr + (uint)(Vector256.Count * 2)); + vector4.Store(dPtr + (uint)(Vector256.Count * 3)); // We load, process, and store the next four vectors - vector1 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 4))), + vector1 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 4))), yVec); - vector2 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 5))), + vector2 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 5))), yVec); - vector3 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 6))), + vector3 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 6))), yVec); - vector4 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 7))), + vector4 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 7))), yVec); - vector1.Store(dPtr + (uint)(Vector256.Count * 4)); - vector2.Store(dPtr + (uint)(Vector256.Count * 5)); 
- vector3.Store(dPtr + (uint)(Vector256.Count * 6)); - vector4.Store(dPtr + (uint)(Vector256.Count * 7)); + vector1.Store(dPtr + (uint)(Vector256.Count * 4)); + vector2.Store(dPtr + (uint)(Vector256.Count * 5)); + vector3.Store(dPtr + (uint)(Vector256.Count * 6)); + vector4.Store(dPtr + (uint)(Vector256.Count * 7)); // We adjust the source and destination references, then update // the count of remaining elements to process. - xPtr += (uint)(Vector256.Count * 8); - dPtr += (uint)(Vector256.Count * 8); + xPtr += (uint)(Vector256.Count * 8); + dPtr += (uint)(Vector256.Count * 8); - remainder -= (uint)(Vector256.Count * 8); + remainder -= (uint)(Vector256.Count * 8); } } @@ -5356,70 +5411,70 @@ static void Vectorized256(ref float xRef, float y, ref float dRef, nuint remaind // data before the first aligned address. nuint endIndex = remainder; - remainder = (remainder + (uint)(Vector256.Count - 1)) & (nuint)(-Vector256.Count); + remainder = (remainder + (uint)(Vector256.Count - 1)) & (nuint)(-Vector256.Count); - switch (remainder / (uint)(Vector256.Count)) + switch (remainder / (uint)Vector256.Count) { case 8: { - Vector256 vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 8))), - yVec); - vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 8)); + Vector256 vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 8))), + yVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 8)); goto case 7; } case 7: { - Vector256 vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 7))), - yVec); - vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 7)); + Vector256 vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 7))), + yVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 7)); goto case 6; } case 6: { - Vector256 vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 6))), - yVec); - vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 6)); + Vector256 vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 6))), + yVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 6)); goto case 5; } case 5: { - Vector256 vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 5))), - yVec); - vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 5)); + Vector256 vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 5))), + yVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 5)); goto case 4; } case 4: { - Vector256 vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 4))), - yVec); - vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 4)); + Vector256 vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 4))), + yVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 4)); goto case 3; } case 3: { - Vector256 vector = 
TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 3))), - yVec); - vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 3)); + Vector256 vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 3))), + yVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 3)); goto case 2; } case 2: { - Vector256 vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 2))), - yVec); - vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 2)); + Vector256 vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 2))), + yVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 2)); goto case 1; } case 1: { // Store the last block, which includes any elements that wouldn't fill a full vector - end.StoreUnsafe(ref dRef, endIndex - (uint)Vector256.Count); + end.StoreUnsafe(ref dRef, endIndex - (uint)Vector256.Count); goto case 0; } @@ -5433,7 +5488,7 @@ static void Vectorized256(ref float xRef, float y, ref float dRef, nuint remaind } [MethodImpl(MethodImplOptions.AggressiveInlining)] - static void Vectorized256Small(ref float xRef, float y, ref float dRef, nuint remainder) + static void Vectorized256Small(ref T xRef, T y, ref T dRef, nuint remainder) { switch (remainder) { @@ -5443,15 +5498,15 @@ static void Vectorized256Small(ref float xRef, float y, ref float dRef, nuint re { Debug.Assert(Vector128.IsHardwareAccelerated); - Vector128 yVec = Vector128.Create(y); + Vector128 yVec = Vector128.Create(y); - Vector128 beg = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef)), - yVec); - Vector128 end = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count))), - yVec); + Vector128 beg = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef)), + yVec); + Vector128 end = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)Vector128.Count)), + yVec); beg.StoreUnsafe(ref dRef); - end.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count)); + end.StoreUnsafe(ref dRef, remainder - (uint)Vector128.Count); break; } @@ -5460,8 +5515,8 @@ static void Vectorized256Small(ref float xRef, float y, ref float dRef, nuint re { Debug.Assert(Vector128.IsHardwareAccelerated); - Vector128 beg = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef)), - Vector128.Create(y)); + Vector128 beg = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef)), + Vector128.Create(y)); beg.StoreUnsafe(ref dRef); break; @@ -5494,35 +5549,35 @@ static void Vectorized256Small(ref float xRef, float y, ref float dRef, nuint re } } - static void Vectorized512(ref float xRef, float y, ref float dRef, nuint remainder) + static void Vectorized512(ref T xRef, T y, ref T dRef, nuint remainder) { - ref float dRefBeg = ref dRef; + ref T dRefBeg = ref dRef; // Preload the beginning and end so that overlapping accesses don't negatively impact the data - Vector512 yVec = Vector512.Create(y); + Vector512 yVec = Vector512.Create(y); - Vector512 beg = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.LoadUnsafe(ref xRef)), - yVec); - Vector512 end = 
TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count))), - yVec); + Vector512 beg = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.LoadUnsafe(ref xRef)), + yVec); + Vector512 end = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)Vector512.Count)), + yVec); - if (remainder > (uint)(Vector512.Count * 8)) + if (remainder > (uint)(Vector512.Count * 8)) { // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted. - fixed (float* px = &xRef) - fixed (float* pd = &dRef) + fixed (T* px = &xRef) + fixed (T* pd = &dRef) { - float* xPtr = px; - float* dPtr = pd; + T* xPtr = px; + T* dPtr = pd; // We need to the ensure the underlying data can be aligned and only align // it if it can. It is possible we have an unaligned ref, in which case we // can never achieve the required SIMD alignment. - bool canAlign = ((nuint)(dPtr) % sizeof(float)) == 0; + bool canAlign = ((nuint)dPtr % (nuint)sizeof(T)) == 0; if (canAlign) { @@ -5532,112 +5587,112 @@ static void Vectorized512(ref float xRef, float y, ref float dRef, nuint remaind // are more expensive than unaligned loads and aligning both is significantly more // complex. - nuint misalignment = ((uint)(sizeof(Vector512)) - ((nuint)(dPtr) % (uint)(sizeof(Vector512)))) / sizeof(float); + nuint misalignment = ((uint)sizeof(Vector512) - ((nuint)dPtr % (uint)sizeof(Vector512))) / (uint)sizeof(T); xPtr += misalignment; dPtr += misalignment; - Debug.Assert(((nuint)(dPtr) % (uint)(sizeof(Vector512))) == 0); + Debug.Assert(((nuint)dPtr % (uint)sizeof(Vector512)) == 0); remainder -= misalignment; } - Vector512 vector1; - Vector512 vector2; - Vector512 vector3; - Vector512 vector4; + Vector512 vector1; + Vector512 vector2; + Vector512 vector3; + Vector512 vector4; - if ((remainder > (NonTemporalByteThreshold / sizeof(float))) && canAlign) + if ((remainder > (NonTemporalByteThreshold / (nuint)sizeof(T))) && canAlign) { // This loop stores the data non-temporally, which benefits us when there // is a large amount of data involved as it avoids polluting the cache. 
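The beg/end pair preloaded above is the key to how every Vectorized* method in this diff finishes: after the bulk loop, remainder is rounded up to a whole number of vectors and a fall-through switch (the goto case chains visible throughout) stores blocks back to front, with case 1 writing the preloaded end at an overlapping offset so a partial final block still comes out correct. Restated with the generic brackets intact, cut down to two cases, and using the IBinaryOperator<T> stand-in from the earlier sketch; the case 0 body sits outside the hunks shown, so it is an assumption from context:

using System.Runtime.Intrinsics;

internal static class RemainderSwitchSketch
{
    public static void StoreTail<T, TBinaryOperator>(
        ref T xRef, ref T yRef, ref T dRef, ref T dRefBeg,
        Vector128<T> beg, Vector128<T> end, nuint remainder)
        where TBinaryOperator : IBinaryOperator<T>
    {
        nuint endIndex = remainder;

        // Round up to a whole number of vectors; beg/end absorb the overlap.
        remainder = (remainder + (uint)(Vector128<T>.Count - 1)) & (nuint)(-Vector128<T>.Count);

        switch (remainder / (uint)Vector128<T>.Count)
        {
            // Cases 8 down to 3 elided; each mirrors case 2 at a larger offset.

            case 2:
            {
                Vector128<T> vector = TBinaryOperator.Invoke(
                    Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128<T>.Count * 2)),
                    Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128<T>.Count * 2)));
                vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128<T>.Count * 2));
                goto case 1;
            }

            case 1:
            {
                // Overlapping store covering any elements short of a full vector.
                end.StoreUnsafe(ref dRef, endIndex - (uint)Vector128<T>.Count);
                goto case 0;
            }

            case 0:
            {
                // Assumed from context: the preloaded first block is written last,
                // covering elements before the aligned region.
                beg.StoreUnsafe(ref dRefBeg);
                break;
            }
        }
    }
}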
- while (remainder >= (uint)(Vector512.Count * 8)) + while (remainder >= (uint)(Vector512.Count * 8)) { // We load, process, and store the first four vectors - vector1 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 0))), + vector1 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 0))), yVec); - vector2 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 1))), + vector2 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 1))), yVec); - vector3 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 2))), + vector3 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 2))), yVec); - vector4 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 3))), + vector4 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 3))), yVec); - vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 0)); - vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 1)); - vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 2)); - vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 3)); + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 0)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 1)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 2)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 3)); // We load, process, and store the next four vectors - vector1 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 4))), + vector1 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 4))), yVec); - vector2 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 5))), + vector2 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 5))), yVec); - vector3 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 6))), + vector3 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 6))), yVec); - vector4 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 7))), + vector4 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 7))), yVec); - vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 4)); - vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 5)); - vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 6)); - vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 7)); + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 4)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 5)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 6)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 7)); // We adjust the source and destination references, then update // the count of remaining elements to process. 
- xPtr += (uint)(Vector512.Count * 8); - dPtr += (uint)(Vector512.Count * 8); + xPtr += (uint)(Vector512.Count * 8); + dPtr += (uint)(Vector512.Count * 8); - remainder -= (uint)(Vector512.Count * 8); + remainder -= (uint)(Vector512.Count * 8); } } else { - while (remainder >= (uint)(Vector512.Count * 8)) + while (remainder >= (uint)(Vector512.Count * 8)) { // We load, process, and store the first four vectors - vector1 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 0))), + vector1 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 0))), yVec); - vector2 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 1))), + vector2 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 1))), yVec); - vector3 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 2))), + vector3 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 2))), yVec); - vector4 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 3))), + vector4 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 3))), yVec); - vector1.Store(dPtr + (uint)(Vector512.Count * 0)); - vector2.Store(dPtr + (uint)(Vector512.Count * 1)); - vector3.Store(dPtr + (uint)(Vector512.Count * 2)); - vector4.Store(dPtr + (uint)(Vector512.Count * 3)); + vector1.Store(dPtr + (uint)(Vector512.Count * 0)); + vector2.Store(dPtr + (uint)(Vector512.Count * 1)); + vector3.Store(dPtr + (uint)(Vector512.Count * 2)); + vector4.Store(dPtr + (uint)(Vector512.Count * 3)); // We load, process, and store the next four vectors - vector1 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 4))), + vector1 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 4))), yVec); - vector2 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 5))), + vector2 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 5))), yVec); - vector3 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 6))), + vector3 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 6))), yVec); - vector4 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 7))), + vector4 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 7))), yVec); - vector1.Store(dPtr + (uint)(Vector512.Count * 4)); - vector2.Store(dPtr + (uint)(Vector512.Count * 5)); - vector3.Store(dPtr + (uint)(Vector512.Count * 6)); - vector4.Store(dPtr + (uint)(Vector512.Count * 7)); + vector1.Store(dPtr + (uint)(Vector512.Count * 4)); + vector2.Store(dPtr + (uint)(Vector512.Count * 5)); + vector3.Store(dPtr + (uint)(Vector512.Count * 6)); + vector4.Store(dPtr + (uint)(Vector512.Count * 7)); // We adjust the source and destination references, then update // the count of remaining elements to process. 
- xPtr += (uint)(Vector512.Count * 8); - dPtr += (uint)(Vector512.Count * 8); + xPtr += (uint)(Vector512.Count * 8); + dPtr += (uint)(Vector512.Count * 8); - remainder -= (uint)(Vector512.Count * 8); + remainder -= (uint)(Vector512.Count * 8); } } @@ -5656,70 +5711,70 @@ static void Vectorized512(ref float xRef, float y, ref float dRef, nuint remaind // data before the first aligned address. nuint endIndex = remainder; - remainder = (remainder + (uint)(Vector512.Count - 1)) & (nuint)(-Vector512.Count); + remainder = (remainder + (uint)(Vector512.Count - 1)) & (nuint)(-Vector512.Count); - switch (remainder / (uint)(Vector512.Count)) + switch (remainder / (uint)Vector512.Count) { case 8: { - Vector512 vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 8))), - yVec); - vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 8)); + Vector512 vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 8))), + yVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 8)); goto case 7; } case 7: { - Vector512 vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 7))), - yVec); - vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 7)); + Vector512 vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 7))), + yVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 7)); goto case 6; } case 6: { - Vector512 vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 6))), - yVec); - vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 6)); + Vector512 vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 6))), + yVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 6)); goto case 5; } case 5: { - Vector512 vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 5))), - yVec); - vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 5)); + Vector512 vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 5))), + yVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 5)); goto case 4; } case 4: { - Vector512 vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 4))), - yVec); - vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 4)); + Vector512 vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 4))), + yVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 4)); goto case 3; } case 3: { - Vector512 vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 3))), - yVec); - vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 3)); + Vector512 vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 3))), + yVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 3)); goto case 2; } case 
2: { - Vector512 vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 2))), - yVec); - vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 2)); + Vector512 vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 2))), + yVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 2)); goto case 1; } case 1: { // Store the last block, which includes any elements that wouldn't fill a full vector - end.StoreUnsafe(ref dRef, endIndex - (uint)Vector512.Count); + end.StoreUnsafe(ref dRef, endIndex - (uint)Vector512.Count); goto case 0; } @@ -5733,7 +5788,7 @@ static void Vectorized512(ref float xRef, float y, ref float dRef, nuint remaind } [MethodImpl(MethodImplOptions.AggressiveInlining)] - static void Vectorized512Small(ref float xRef, float y, ref float dRef, nuint remainder) + static void Vectorized512Small(ref T xRef, T y, ref T dRef, nuint remainder) { switch (remainder) { @@ -5747,15 +5802,15 @@ static void Vectorized512Small(ref float xRef, float y, ref float dRef, nuint re { Debug.Assert(Vector256.IsHardwareAccelerated); - Vector256 yVec = Vector256.Create(y); + Vector256 yVec = Vector256.Create(y); - Vector256 beg = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef)), - yVec); - Vector256 end = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count))), - yVec); + Vector256 beg = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef)), + yVec); + Vector256 end = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)Vector256.Count)), + yVec); beg.StoreUnsafe(ref dRef); - end.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count)); + end.StoreUnsafe(ref dRef, remainder - (uint)Vector256.Count); break; } @@ -5764,8 +5819,8 @@ static void Vectorized512Small(ref float xRef, float y, ref float dRef, nuint re { Debug.Assert(Vector256.IsHardwareAccelerated); - Vector256 beg = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef)), - Vector256.Create(y)); + Vector256 beg = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef)), + Vector256.Create(y)); beg.StoreUnsafe(ref dRef); break; @@ -5777,15 +5832,15 @@ static void Vectorized512Small(ref float xRef, float y, ref float dRef, nuint re { Debug.Assert(Vector128.IsHardwareAccelerated); - Vector128 yVec = Vector128.Create(y); + Vector128 yVec = Vector128.Create(y); - Vector128 beg = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef)), - yVec); - Vector128 end = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count))), - yVec); + Vector128 beg = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef)), + yVec); + Vector128 end = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)Vector128.Count)), + yVec); beg.StoreUnsafe(ref dRef); - end.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count)); + end.StoreUnsafe(ref dRef, remainder - (uint)Vector128.Count); break; } @@ -5794,8 +5849,8 @@ static void Vectorized512Small(ref float xRef, float y, ref float dRef, nuint re { Debug.Assert(Vector128.IsHardwareAccelerated); - Vector128 beg = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.LoadUnsafe(ref 
xRef)), - Vector128.Create(y)); + Vector128 beg = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef)), + Vector128.Create(y)); beg.StoreUnsafe(ref dRef); break; @@ -5833,13 +5888,14 @@ static void Vectorized512Small(ref float xRef, float y, ref float dRef, nuint re /// Performs an element-wise operation on , , and , /// and writes the results to . /// + /// The element type. /// /// Specifies the operation to perform on the pair-wise elements loaded from , , /// and . /// - private static void InvokeSpanSpanSpanIntoSpan( - ReadOnlySpan x, ReadOnlySpan y, ReadOnlySpan z, Span destination) - where TTernaryOperator : struct, ITernaryOperator + private static void InvokeSpanSpanSpanIntoSpan( + ReadOnlySpan x, ReadOnlySpan y, ReadOnlySpan z, Span destination) + where TTernaryOperator : struct, ITernaryOperator { if (x.Length != y.Length || x.Length != z.Length) { @@ -5860,16 +5916,16 @@ private static void InvokeSpanSpanSpanIntoSpan( // in a way that allows us to have the minimum possible // for small sizes - ref float xRef = ref MemoryMarshal.GetReference(x); - ref float yRef = ref MemoryMarshal.GetReference(y); - ref float zRef = ref MemoryMarshal.GetReference(z); - ref float dRef = ref MemoryMarshal.GetReference(destination); + ref T xRef = ref MemoryMarshal.GetReference(x); + ref T yRef = ref MemoryMarshal.GetReference(y); + ref T zRef = ref MemoryMarshal.GetReference(z); + ref T dRef = ref MemoryMarshal.GetReference(destination); - nuint remainder = (uint)(x.Length); + nuint remainder = (uint)x.Length; - if (Vector512.IsHardwareAccelerated) + if (Vector512.IsHardwareAccelerated && Vector512.IsSupported && Unsafe.SizeOf() >= 4) { - if (remainder >= (uint)(Vector512.Count)) + if (remainder >= (uint)Vector512.Count) { Vectorized512(ref xRef, ref yRef, ref zRef, ref dRef, remainder); } @@ -5885,9 +5941,9 @@ private static void InvokeSpanSpanSpanIntoSpan( return; } - if (Vector256.IsHardwareAccelerated) + if (Vector256.IsHardwareAccelerated && Vector256.IsSupported && Unsafe.SizeOf() >= 4) { - if (remainder >= (uint)(Vector256.Count)) + if (remainder >= (uint)Vector256.Count) { Vectorized256(ref xRef, ref yRef, ref zRef, ref dRef, remainder); } @@ -5903,9 +5959,9 @@ private static void InvokeSpanSpanSpanIntoSpan( return; } - if (Vector128.IsHardwareAccelerated) + if (Vector128.IsHardwareAccelerated && Vector128.IsSupported && Unsafe.SizeOf() >= 4) { - if (remainder >= (uint)(Vector128.Count)) + if (remainder >= (uint)Vector128.Count) { Vectorized128(ref xRef, ref yRef, ref zRef, ref dRef, remainder); } @@ -5927,7 +5983,7 @@ private static void InvokeSpanSpanSpanIntoSpan( SoftwareFallback(ref xRef, ref yRef, ref zRef, ref dRef, remainder); [MethodImpl(MethodImplOptions.AggressiveInlining)] - static void SoftwareFallback(ref float xRef, ref float yRef, ref float zRef, ref float dRef, nuint length) + static void SoftwareFallback(ref T xRef, ref T yRef, ref T zRef, ref T dRef, nuint length) { for (nuint i = 0; i < length; i++) { @@ -5937,39 +5993,39 @@ static void SoftwareFallback(ref float xRef, ref float yRef, ref float zRef, ref } } - static void Vectorized128(ref float xRef, ref float yRef, ref float zRef, ref float dRef, nuint remainder) + static void Vectorized128(ref T xRef, ref T yRef, ref T zRef, ref T dRef, nuint remainder) { - ref float dRefBeg = ref dRef; + ref T dRefBeg = ref dRef; // Preload the beginning and end so that overlapping accesses don't negatively impact the data - Vector128 beg = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef), - 
Vector128.LoadUnsafe(ref yRef), - Vector128.LoadUnsafe(ref zRef)); - Vector128 end = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count)), - Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count)), - Vector128.LoadUnsafe(ref zRef, remainder - (uint)(Vector128.Count))); + Vector128 beg = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef), + Vector128.LoadUnsafe(ref yRef), + Vector128.LoadUnsafe(ref zRef)); + Vector128 end = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)Vector128.Count), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)Vector128.Count), + Vector128.LoadUnsafe(ref zRef, remainder - (uint)Vector128.Count)); - if (remainder > (uint)(Vector128.Count * 8)) + if (remainder > (uint)(Vector128.Count * 8)) { // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted. - fixed (float* px = &xRef) - fixed (float* py = &yRef) - fixed (float* pz = &zRef) - fixed (float* pd = &dRef) + fixed (T* px = &xRef) + fixed (T* py = &yRef) + fixed (T* pz = &zRef) + fixed (T* pd = &dRef) { - float* xPtr = px; - float* yPtr = py; - float* zPtr = pz; - float* dPtr = pd; + T* xPtr = px; + T* yPtr = py; + T* zPtr = pz; + T* dPtr = pd; // We need to ensure the underlying data can be aligned and only align // it if it can. It is possible we have an unaligned ref, in which case we // can never achieve the required SIMD alignment. - bool canAlign = ((nuint)(dPtr) % sizeof(float)) == 0; + bool canAlign = ((nuint)dPtr % (nuint)sizeof(T)) == 0; if (canAlign) { @@ -5979,134 +6035,134 @@ static void Vectorized128(ref float xRef, ref float yRef, ref float zRef, ref fl // are more expensive than unaligned loads and aligning both is significantly more // complex. - nuint misalignment = ((uint)(sizeof(Vector128)) - ((nuint)(dPtr) % (uint)(sizeof(Vector128)))) / sizeof(float); + nuint misalignment = ((uint)sizeof(Vector128) - ((nuint)dPtr % (uint)sizeof(Vector128))) / (uint)sizeof(T); xPtr += misalignment; yPtr += misalignment; zPtr += misalignment; dPtr += misalignment; - Debug.Assert(((nuint)(dPtr) % (uint)(sizeof(Vector128))) == 0); + Debug.Assert(((nuint)dPtr % (uint)sizeof(Vector128)) == 0); remainder -= misalignment; } - Vector128 vector1; - Vector128 vector2; - Vector128 vector3; - Vector128 vector4; + Vector128 vector1; + Vector128 vector2; + Vector128 vector3; + Vector128 vector4; - if ((remainder > (NonTemporalByteThreshold / sizeof(float))) && canAlign) + if ((remainder > (NonTemporalByteThreshold / (nuint)sizeof(T))) && canAlign) { // This loop stores the data non-temporally, which benefits us when there // is a large amount of data involved as it avoids polluting the cache.
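As an aside on the non-temporal path that follows: a minimal, self-contained sketch of the same store pattern, assuming an unmanaged T and a 16-byte-aligned destination (the helper name and the scalar tail are illustrative, not part of the patch):

    using System.Runtime.Intrinsics;

    static unsafe void CopyNonTemporal<T>(T* src, T* dst, nuint count) where T : unmanaged
    {
        nuint i = 0;
        if (Vector128.IsHardwareAccelerated && Vector128<T>.IsSupported)
        {
            for (; i + (nuint)Vector128<T>.Count <= count; i += (nuint)Vector128<T>.Count)
            {
                // Bypasses the cache on the store side; dst + i must stay 16-byte aligned,
                // which is why the surrounding code aligns dPtr before taking this path.
                Vector128.Load(src + i).StoreAlignedNonTemporal(dst + i);
            }
        }
        for (; i < count; i++)
        {
            dst[i] = src[i]; // scalar tail for any leftover elements
        }
    }

Regular Store remains the better choice for data that will be read again soon; the NonTemporalByteThreshold check above exists precisely to restrict this path to inputs too large to benefit from caching.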
- while (remainder >= (uint)(Vector128.Count * 8)) + while (remainder >= (uint)(Vector128.Count * 8)) { // We load, process, and store the first four vectors - vector1 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 0)), - Vector128.Load(yPtr + (uint)(Vector128.Count * 0)), - Vector128.Load(zPtr + (uint)(Vector128.Count * 0))); - vector2 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 1)), - Vector128.Load(yPtr + (uint)(Vector128.Count * 1)), - Vector128.Load(zPtr + (uint)(Vector128.Count * 1))); - vector3 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 2)), - Vector128.Load(yPtr + (uint)(Vector128.Count * 2)), - Vector128.Load(zPtr + (uint)(Vector128.Count * 2))); - vector4 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 3)), - Vector128.Load(yPtr + (uint)(Vector128.Count * 3)), - Vector128.Load(zPtr + (uint)(Vector128.Count * 3))); - - vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 0)); - vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 1)); - vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 2)); - vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 3)); + vector1 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 0)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 0)), + Vector128.Load(zPtr + (uint)(Vector128.Count * 0))); + vector2 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 1)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 1)), + Vector128.Load(zPtr + (uint)(Vector128.Count * 1))); + vector3 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 2)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 2)), + Vector128.Load(zPtr + (uint)(Vector128.Count * 2))); + vector4 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 3)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 3)), + Vector128.Load(zPtr + (uint)(Vector128.Count * 3))); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 0)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 1)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 2)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 3)); // We load, process, and store the next four vectors - vector1 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 4)), - Vector128.Load(yPtr + (uint)(Vector128.Count * 4)), - Vector128.Load(zPtr + (uint)(Vector128.Count * 4))); - vector2 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 5)), - Vector128.Load(yPtr + (uint)(Vector128.Count * 5)), - Vector128.Load(zPtr + (uint)(Vector128.Count * 5))); - vector3 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 6)), - Vector128.Load(yPtr + (uint)(Vector128.Count * 6)), - Vector128.Load(zPtr + (uint)(Vector128.Count * 6))); - vector4 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 7)), - Vector128.Load(yPtr + (uint)(Vector128.Count * 7)), - Vector128.Load(zPtr + (uint)(Vector128.Count * 7))); - - vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 4)); - vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 5)); - vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 6)); - vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 7)); + vector1 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 4)), + Vector128.Load(yPtr + 
(uint)(Vector128.Count * 4)), + Vector128.Load(zPtr + (uint)(Vector128.Count * 4))); + vector2 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 5)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 5)), + Vector128.Load(zPtr + (uint)(Vector128.Count * 5))); + vector3 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 6)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 6)), + Vector128.Load(zPtr + (uint)(Vector128.Count * 6))); + vector4 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 7)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 7)), + Vector128.Load(zPtr + (uint)(Vector128.Count * 7))); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 4)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 5)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 6)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 7)); // We adjust the source and destination references, then update // the count of remaining elements to process. - xPtr += (uint)(Vector128.Count * 8); - yPtr += (uint)(Vector128.Count * 8); - zPtr += (uint)(Vector128.Count * 8); - dPtr += (uint)(Vector128.Count * 8); + xPtr += (uint)(Vector128.Count * 8); + yPtr += (uint)(Vector128.Count * 8); + zPtr += (uint)(Vector128.Count * 8); + dPtr += (uint)(Vector128.Count * 8); - remainder -= (uint)(Vector128.Count * 8); + remainder -= (uint)(Vector128.Count * 8); } } else { - while (remainder >= (uint)(Vector128.Count * 8)) + while (remainder >= (uint)(Vector128.Count * 8)) { // We load, process, and store the first four vectors - vector1 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 0)), - Vector128.Load(yPtr + (uint)(Vector128.Count * 0)), - Vector128.Load(zPtr + (uint)(Vector128.Count * 0))); - vector2 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 1)), - Vector128.Load(yPtr + (uint)(Vector128.Count * 1)), - Vector128.Load(zPtr + (uint)(Vector128.Count * 1))); - vector3 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 2)), - Vector128.Load(yPtr + (uint)(Vector128.Count * 2)), - Vector128.Load(zPtr + (uint)(Vector128.Count * 2))); - vector4 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 3)), - Vector128.Load(yPtr + (uint)(Vector128.Count * 3)), - Vector128.Load(zPtr + (uint)(Vector128.Count * 3))); - - vector1.Store(dPtr + (uint)(Vector128.Count * 0)); - vector2.Store(dPtr + (uint)(Vector128.Count * 1)); - vector3.Store(dPtr + (uint)(Vector128.Count * 2)); - vector4.Store(dPtr + (uint)(Vector128.Count * 3)); + vector1 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 0)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 0)), + Vector128.Load(zPtr + (uint)(Vector128.Count * 0))); + vector2 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 1)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 1)), + Vector128.Load(zPtr + (uint)(Vector128.Count * 1))); + vector3 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 2)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 2)), + Vector128.Load(zPtr + (uint)(Vector128.Count * 2))); + vector4 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 3)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 3)), + Vector128.Load(zPtr + (uint)(Vector128.Count * 3))); + + vector1.Store(dPtr + (uint)(Vector128.Count * 0)); + vector2.Store(dPtr + (uint)(Vector128.Count * 1)); + 
vector3.Store(dPtr + (uint)(Vector128.Count * 2)); + vector4.Store(dPtr + (uint)(Vector128.Count * 3)); // We load, process, and store the next four vectors - vector1 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 4)), - Vector128.Load(yPtr + (uint)(Vector128.Count * 4)), - Vector128.Load(zPtr + (uint)(Vector128.Count * 4))); - vector2 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 5)), - Vector128.Load(yPtr + (uint)(Vector128.Count * 5)), - Vector128.Load(zPtr + (uint)(Vector128.Count * 5))); - vector3 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 6)), - Vector128.Load(yPtr + (uint)(Vector128.Count * 6)), - Vector128.Load(zPtr + (uint)(Vector128.Count * 6))); - vector4 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 7)), - Vector128.Load(yPtr + (uint)(Vector128.Count * 7)), - Vector128.Load(zPtr + (uint)(Vector128.Count * 7))); - - vector1.Store(dPtr + (uint)(Vector128.Count * 4)); - vector2.Store(dPtr + (uint)(Vector128.Count * 5)); - vector3.Store(dPtr + (uint)(Vector128.Count * 6)); - vector4.Store(dPtr + (uint)(Vector128.Count * 7)); + vector1 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 4)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 4)), + Vector128.Load(zPtr + (uint)(Vector128.Count * 4))); + vector2 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 5)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 5)), + Vector128.Load(zPtr + (uint)(Vector128.Count * 5))); + vector3 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 6)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 6)), + Vector128.Load(zPtr + (uint)(Vector128.Count * 6))); + vector4 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 7)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 7)), + Vector128.Load(zPtr + (uint)(Vector128.Count * 7))); + + vector1.Store(dPtr + (uint)(Vector128.Count * 4)); + vector2.Store(dPtr + (uint)(Vector128.Count * 5)); + vector3.Store(dPtr + (uint)(Vector128.Count * 6)); + vector4.Store(dPtr + (uint)(Vector128.Count * 7)); // We adjust the source and destination references, then update // the count of remaining elements to process. - xPtr += (uint)(Vector128.Count * 8); - yPtr += (uint)(Vector128.Count * 8); - zPtr += (uint)(Vector128.Count * 8); - dPtr += (uint)(Vector128.Count * 8); + xPtr += (uint)(Vector128.Count * 8); + yPtr += (uint)(Vector128.Count * 8); + zPtr += (uint)(Vector128.Count * 8); + dPtr += (uint)(Vector128.Count * 8); - remainder -= (uint)(Vector128.Count * 8); + remainder -= (uint)(Vector128.Count * 8); } } @@ -6127,77 +6183,77 @@ static void Vectorized128(ref float xRef, ref float yRef, ref float zRef, ref fl // data before the first aligned address. 
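The tail handling that follows rounds the remaining element count up to a whole number of vectors and then falls through the switch from the largest block downward. A worked example with illustrative values:

    // With Vector128<float>.Count == 4 and 10 elements left over:
    nuint remainder = 10;
    nuint count = (nuint)Vector128<float>.Count;                       // 4
    nuint rounded = (remainder + (count - 1)) & (nuint)(-(nint)count); // (10 + 3) & ~3 == 12
    // 12 / 4 == 3, so execution enters at 'case 3' and falls through to 'case 1',
    // whose store of the preloaded 'end' vector deliberately overlaps elements
    // already written by 'case 2', covering the final partial block.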
nuint endIndex = remainder; - remainder = (remainder + (uint)(Vector128.Count - 1)) & (nuint)(-Vector128.Count); + remainder = (remainder + (uint)(Vector128.Count - 1)) & (nuint)(-Vector128.Count); - switch (remainder / (uint)(Vector128.Count)) + switch (remainder / (uint)Vector128.Count) { case 8: { - Vector128 vector = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 8)), - Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 8)), - Vector128.LoadUnsafe(ref zRef, remainder - (uint)(Vector128.Count * 8))); - vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 8)); + Vector128 vector = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 8)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 8)), + Vector128.LoadUnsafe(ref zRef, remainder - (uint)(Vector128.Count * 8))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 8)); goto case 7; } case 7: { - Vector128 vector = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 7)), - Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 7)), - Vector128.LoadUnsafe(ref zRef, remainder - (uint)(Vector128.Count * 7))); - vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 7)); + Vector128 vector = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 7)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 7)), + Vector128.LoadUnsafe(ref zRef, remainder - (uint)(Vector128.Count * 7))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 7)); goto case 6; } case 6: { - Vector128 vector = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 6)), - Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 6)), - Vector128.LoadUnsafe(ref zRef, remainder - (uint)(Vector128.Count * 6))); - vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 6)); + Vector128 vector = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 6)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 6)), + Vector128.LoadUnsafe(ref zRef, remainder - (uint)(Vector128.Count * 6))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 6)); goto case 5; } case 5: { - Vector128 vector = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 5)), - Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 5)), - Vector128.LoadUnsafe(ref zRef, remainder - (uint)(Vector128.Count * 5))); - vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 5)); + Vector128 vector = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 5)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 5)), + Vector128.LoadUnsafe(ref zRef, remainder - (uint)(Vector128.Count * 5))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 5)); goto case 4; } case 4: { - Vector128 vector = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 4)), - Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 4)), - Vector128.LoadUnsafe(ref zRef, remainder - (uint)(Vector128.Count * 4))); - vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 4)); + Vector128 vector = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref 
xRef, remainder - (uint)(Vector128.Count * 4)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 4)), + Vector128.LoadUnsafe(ref zRef, remainder - (uint)(Vector128.Count * 4))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 4)); goto case 3; } case 3: { - Vector128 vector = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 3)), - Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 3)), - Vector128.LoadUnsafe(ref zRef, remainder - (uint)(Vector128.Count * 3))); - vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 3)); + Vector128 vector = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 3)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 3)), + Vector128.LoadUnsafe(ref zRef, remainder - (uint)(Vector128.Count * 3))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 3)); goto case 2; } case 2: { - Vector128 vector = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 2)), - Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 2)), - Vector128.LoadUnsafe(ref zRef, remainder - (uint)(Vector128.Count * 2))); - vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 2)); + Vector128 vector = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 2)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 2)), + Vector128.LoadUnsafe(ref zRef, remainder - (uint)(Vector128.Count * 2))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 2)); goto case 1; } case 1: { // Store the last block, which includes any elements that wouldn't fill a full vector - end.StoreUnsafe(ref dRef, endIndex - (uint)Vector128.Count); + end.StoreUnsafe(ref dRef, endIndex - (uint)Vector128.Count); goto case 0; } @@ -6211,7 +6267,7 @@ static void Vectorized128(ref float xRef, ref float yRef, ref float zRef, ref fl } [MethodImpl(MethodImplOptions.AggressiveInlining)] - static void Vectorized128Small(ref float xRef, ref float yRef, ref float zRef, ref float dRef, nuint remainder) + static void Vectorized128Small(ref T xRef, ref T yRef, ref T zRef, ref T dRef, nuint remainder) { switch (remainder) { @@ -6244,39 +6300,39 @@ static void Vectorized128Small(ref float xRef, ref float yRef, ref float zRef, r } } - static void Vectorized256(ref float xRef, ref float yRef, ref float zRef, ref float dRef, nuint remainder) + static void Vectorized256(ref T xRef, ref T yRef, ref T zRef, ref T dRef, nuint remainder) { - ref float dRefBeg = ref dRef; + ref T dRefBeg = ref dRef; // Preload the beginning and end so that overlapping accesses don't negatively impact the data - Vector256 beg = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef), - Vector256.LoadUnsafe(ref yRef), - Vector256.LoadUnsafe(ref zRef)); - Vector256 end = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count)), - Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count)), - Vector256.LoadUnsafe(ref zRef, remainder - (uint)(Vector256.Count))); + Vector256 beg = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef), + Vector256.LoadUnsafe(ref yRef), + Vector256.LoadUnsafe(ref zRef)); + Vector256 end = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)Vector256.Count), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)Vector256.Count), + Vector256.LoadUnsafe(ref zRef, 
remainder - (uint)Vector256.Count)); - if (remainder > (uint)(Vector256.Count * 8)) + if (remainder > (uint)(Vector256.Count * 8)) { // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted. - fixed (float* px = &xRef) - fixed (float* py = &yRef) - fixed (float* pz = &zRef) - fixed (float* pd = &dRef) + fixed (T* px = &xRef) + fixed (T* py = &yRef) + fixed (T* pz = &zRef) + fixed (T* pd = &dRef) { - float* xPtr = px; - float* yPtr = py; - float* zPtr = pz; - float* dPtr = pd; + T* xPtr = px; + T* yPtr = py; + T* zPtr = pz; + T* dPtr = pd; // We need to ensure the underlying data can be aligned and only align // it if it can. It is possible we have an unaligned ref, in which case we // can never achieve the required SIMD alignment. - bool canAlign = ((nuint)(dPtr) % sizeof(float)) == 0; + bool canAlign = ((nuint)dPtr % (nuint)sizeof(T)) == 0; if (canAlign) { @@ -6286,134 +6342,134 @@ static void Vectorized256(ref float xRef, ref float yRef, ref float zRef, ref fl // are more expensive than unaligned loads and aligning both is significantly more // complex. - nuint misalignment = ((uint)(sizeof(Vector256)) - ((nuint)(dPtr) % (uint)(sizeof(Vector256)))) / sizeof(float); + nuint misalignment = ((uint)sizeof(Vector256) - ((nuint)dPtr % (uint)sizeof(Vector256))) / (nuint)sizeof(T); xPtr += misalignment; yPtr += misalignment; zPtr += misalignment; dPtr += misalignment; - Debug.Assert(((nuint)(dPtr) % (uint)(sizeof(Vector256))) == 0); + Debug.Assert(((nuint)dPtr % (uint)sizeof(Vector256)) == 0); remainder -= misalignment; } - Vector256 vector1; - Vector256 vector2; - Vector256 vector3; - Vector256 vector4; + Vector256 vector1; + Vector256 vector2; + Vector256 vector3; + Vector256 vector4; - if ((remainder > (NonTemporalByteThreshold / sizeof(float))) && canAlign) + if ((remainder > (NonTemporalByteThreshold / (nuint)sizeof(T))) && canAlign) { // This loop stores the data non-temporally, which benefits us when there // is a large amount of data involved as it avoids polluting the cache.
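One consequence of the generic conversion visible just above: the non-temporal cutoff is expressed in bytes, so dividing by sizeof(T) turns it into a per-type element count. Illustrative arithmetic (the real NonTemporalByteThreshold is internal to the library; the value below is assumed):

    nuint NonTemporalByteThreshold = 256 * 1024;                              // hypothetical value
    nuint floatThreshold  = NonTemporalByteThreshold / (nuint)sizeof(float);  // 65536 elements
    nuint doubleThreshold = NonTemporalByteThreshold / (nuint)sizeof(double); // 32768 elements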
- while (remainder >= (uint)(Vector256.Count * 8)) + while (remainder >= (uint)(Vector256.Count * 8)) { // We load, process, and store the first four vectors - vector1 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 0)), - Vector256.Load(yPtr + (uint)(Vector256.Count * 0)), - Vector256.Load(zPtr + (uint)(Vector256.Count * 0))); - vector2 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 1)), - Vector256.Load(yPtr + (uint)(Vector256.Count * 1)), - Vector256.Load(zPtr + (uint)(Vector256.Count * 1))); - vector3 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 2)), - Vector256.Load(yPtr + (uint)(Vector256.Count * 2)), - Vector256.Load(zPtr + (uint)(Vector256.Count * 2))); - vector4 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 3)), - Vector256.Load(yPtr + (uint)(Vector256.Count * 3)), - Vector256.Load(zPtr + (uint)(Vector256.Count * 3))); - - vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 0)); - vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 1)); - vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 2)); - vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 3)); + vector1 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 0)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 0)), + Vector256.Load(zPtr + (uint)(Vector256.Count * 0))); + vector2 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 1)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 1)), + Vector256.Load(zPtr + (uint)(Vector256.Count * 1))); + vector3 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 2)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 2)), + Vector256.Load(zPtr + (uint)(Vector256.Count * 2))); + vector4 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 3)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 3)), + Vector256.Load(zPtr + (uint)(Vector256.Count * 3))); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 0)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 1)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 2)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 3)); // We load, process, and store the next four vectors - vector1 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 4)), - Vector256.Load(yPtr + (uint)(Vector256.Count * 4)), - Vector256.Load(zPtr + (uint)(Vector256.Count * 4))); - vector2 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 5)), - Vector256.Load(yPtr + (uint)(Vector256.Count * 5)), - Vector256.Load(zPtr + (uint)(Vector256.Count * 5))); - vector3 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 6)), - Vector256.Load(yPtr + (uint)(Vector256.Count * 6)), - Vector256.Load(zPtr + (uint)(Vector256.Count * 6))); - vector4 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 7)), - Vector256.Load(yPtr + (uint)(Vector256.Count * 7)), - Vector256.Load(zPtr + (uint)(Vector256.Count * 7))); - - vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 4)); - vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 5)); - vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 6)); - vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 7)); + vector1 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 4)), + Vector256.Load(yPtr + 
(uint)(Vector256.Count * 4)), + Vector256.Load(zPtr + (uint)(Vector256.Count * 4))); + vector2 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 5)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 5)), + Vector256.Load(zPtr + (uint)(Vector256.Count * 5))); + vector3 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 6)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 6)), + Vector256.Load(zPtr + (uint)(Vector256.Count * 6))); + vector4 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 7)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 7)), + Vector256.Load(zPtr + (uint)(Vector256.Count * 7))); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 4)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 5)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 6)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 7)); // We adjust the source and destination references, then update // the count of remaining elements to process. - xPtr += (uint)(Vector256.Count * 8); - yPtr += (uint)(Vector256.Count * 8); - zPtr += (uint)(Vector256.Count * 8); - dPtr += (uint)(Vector256.Count * 8); + xPtr += (uint)(Vector256.Count * 8); + yPtr += (uint)(Vector256.Count * 8); + zPtr += (uint)(Vector256.Count * 8); + dPtr += (uint)(Vector256.Count * 8); - remainder -= (uint)(Vector256.Count * 8); + remainder -= (uint)(Vector256.Count * 8); } } else { - while (remainder >= (uint)(Vector256.Count * 8)) + while (remainder >= (uint)(Vector256.Count * 8)) { // We load, process, and store the first four vectors - vector1 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 0)), - Vector256.Load(yPtr + (uint)(Vector256.Count * 0)), - Vector256.Load(zPtr + (uint)(Vector256.Count * 0))); - vector2 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 1)), - Vector256.Load(yPtr + (uint)(Vector256.Count * 1)), - Vector256.Load(zPtr + (uint)(Vector256.Count * 1))); - vector3 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 2)), - Vector256.Load(yPtr + (uint)(Vector256.Count * 2)), - Vector256.Load(zPtr + (uint)(Vector256.Count * 2))); - vector4 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 3)), - Vector256.Load(yPtr + (uint)(Vector256.Count * 3)), - Vector256.Load(zPtr + (uint)(Vector256.Count * 3))); - - vector1.Store(dPtr + (uint)(Vector256.Count * 0)); - vector2.Store(dPtr + (uint)(Vector256.Count * 1)); - vector3.Store(dPtr + (uint)(Vector256.Count * 2)); - vector4.Store(dPtr + (uint)(Vector256.Count * 3)); + vector1 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 0)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 0)), + Vector256.Load(zPtr + (uint)(Vector256.Count * 0))); + vector2 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 1)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 1)), + Vector256.Load(zPtr + (uint)(Vector256.Count * 1))); + vector3 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 2)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 2)), + Vector256.Load(zPtr + (uint)(Vector256.Count * 2))); + vector4 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 3)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 3)), + Vector256.Load(zPtr + (uint)(Vector256.Count * 3))); + + vector1.Store(dPtr + (uint)(Vector256.Count * 0)); + vector2.Store(dPtr + (uint)(Vector256.Count * 1)); + 
vector3.Store(dPtr + (uint)(Vector256.Count * 2)); + vector4.Store(dPtr + (uint)(Vector256.Count * 3)); // We load, process, and store the next four vectors - vector1 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 4)), - Vector256.Load(yPtr + (uint)(Vector256.Count * 4)), - Vector256.Load(zPtr + (uint)(Vector256.Count * 4))); - vector2 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 5)), - Vector256.Load(yPtr + (uint)(Vector256.Count * 5)), - Vector256.Load(zPtr + (uint)(Vector256.Count * 5))); - vector3 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 6)), - Vector256.Load(yPtr + (uint)(Vector256.Count * 6)), - Vector256.Load(zPtr + (uint)(Vector256.Count * 6))); - vector4 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 7)), - Vector256.Load(yPtr + (uint)(Vector256.Count * 7)), - Vector256.Load(zPtr + (uint)(Vector256.Count * 7))); - - vector1.Store(dPtr + (uint)(Vector256.Count * 4)); - vector2.Store(dPtr + (uint)(Vector256.Count * 5)); - vector3.Store(dPtr + (uint)(Vector256.Count * 6)); - vector4.Store(dPtr + (uint)(Vector256.Count * 7)); + vector1 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 4)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 4)), + Vector256.Load(zPtr + (uint)(Vector256.Count * 4))); + vector2 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 5)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 5)), + Vector256.Load(zPtr + (uint)(Vector256.Count * 5))); + vector3 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 6)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 6)), + Vector256.Load(zPtr + (uint)(Vector256.Count * 6))); + vector4 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 7)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 7)), + Vector256.Load(zPtr + (uint)(Vector256.Count * 7))); + + vector1.Store(dPtr + (uint)(Vector256.Count * 4)); + vector2.Store(dPtr + (uint)(Vector256.Count * 5)); + vector3.Store(dPtr + (uint)(Vector256.Count * 6)); + vector4.Store(dPtr + (uint)(Vector256.Count * 7)); // We adjust the source and destination references, then update // the count of remaining elements to process. - xPtr += (uint)(Vector256.Count * 8); - yPtr += (uint)(Vector256.Count * 8); - zPtr += (uint)(Vector256.Count * 8); - dPtr += (uint)(Vector256.Count * 8); + xPtr += (uint)(Vector256.Count * 8); + yPtr += (uint)(Vector256.Count * 8); + zPtr += (uint)(Vector256.Count * 8); + dPtr += (uint)(Vector256.Count * 8); - remainder -= (uint)(Vector256.Count * 8); + remainder -= (uint)(Vector256.Count * 8); } } @@ -6434,77 +6490,77 @@ static void Vectorized256(ref float xRef, ref float yRef, ref float zRef, ref fl // data before the first aligned address. 
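Before the trailing switch below, it may help to see why the preloaded 'end' vector makes the tail safe: it is read from the last vector-width window of the input up front (so the values survive even if source and destination alias) and stored last, deliberately overlapping elements the main loop already wrote. A minimal sketch for a plain copy, assuming x.Length >= Vector256<float>.Count and d at least as long as x:

    using System.Runtime.InteropServices;
    using System.Runtime.Intrinsics;

    static void TailByOverlap(ReadOnlySpan<float> x, Span<float> d)
    {
        nuint length = (nuint)x.Length;
        ref float xRef = ref MemoryMarshal.GetReference(x);
        ref float dRef = ref MemoryMarshal.GetReference(d);

        // Load the final full-width window before any stores happen.
        Vector256<float> end = Vector256.LoadUnsafe(ref xRef, length - (uint)Vector256<float>.Count);

        for (nuint i = 0; i + (nuint)Vector256<float>.Count <= length; i += (nuint)Vector256<float>.Count)
        {
            Vector256.LoadUnsafe(ref xRef, i).StoreUnsafe(ref dRef, i);
        }

        // Cover the partial tail by overlapping the last full store; no scalar loop needed.
        end.StoreUnsafe(ref dRef, length - (uint)Vector256<float>.Count);
    }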
nuint endIndex = remainder; - remainder = (remainder + (uint)(Vector256.Count - 1)) & (nuint)(-Vector256.Count); + remainder = (remainder + (uint)(Vector256.Count - 1)) & (nuint)(-Vector256.Count); - switch (remainder / (uint)(Vector256.Count)) + switch (remainder / (uint)Vector256.Count) { case 8: { - Vector256 vector = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 8)), - Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 8)), - Vector256.LoadUnsafe(ref zRef, remainder - (uint)(Vector256.Count * 8))); - vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 8)); + Vector256 vector = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 8)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 8)), + Vector256.LoadUnsafe(ref zRef, remainder - (uint)(Vector256.Count * 8))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 8)); goto case 7; } case 7: { - Vector256 vector = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 7)), - Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 7)), - Vector256.LoadUnsafe(ref zRef, remainder - (uint)(Vector256.Count * 7))); - vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 7)); + Vector256 vector = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 7)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 7)), + Vector256.LoadUnsafe(ref zRef, remainder - (uint)(Vector256.Count * 7))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 7)); goto case 6; } case 6: { - Vector256 vector = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 6)), - Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 6)), - Vector256.LoadUnsafe(ref zRef, remainder - (uint)(Vector256.Count * 6))); - vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 6)); + Vector256 vector = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 6)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 6)), + Vector256.LoadUnsafe(ref zRef, remainder - (uint)(Vector256.Count * 6))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 6)); goto case 5; } case 5: { - Vector256 vector = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 5)), - Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 5)), - Vector256.LoadUnsafe(ref zRef, remainder - (uint)(Vector256.Count * 5))); - vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 5)); + Vector256 vector = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 5)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 5)), + Vector256.LoadUnsafe(ref zRef, remainder - (uint)(Vector256.Count * 5))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 5)); goto case 4; } case 4: { - Vector256 vector = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 4)), - Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 4)), - Vector256.LoadUnsafe(ref zRef, remainder - (uint)(Vector256.Count * 4))); - vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 4)); + Vector256 vector = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref 
xRef, remainder - (uint)(Vector256.Count * 4)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 4)), + Vector256.LoadUnsafe(ref zRef, remainder - (uint)(Vector256.Count * 4))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 4)); goto case 3; } case 3: { - Vector256 vector = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 3)), - Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 3)), - Vector256.LoadUnsafe(ref zRef, remainder - (uint)(Vector256.Count * 3))); - vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 3)); + Vector256 vector = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 3)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 3)), + Vector256.LoadUnsafe(ref zRef, remainder - (uint)(Vector256.Count * 3))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 3)); goto case 2; } case 2: { - Vector256 vector = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 2)), - Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 2)), - Vector256.LoadUnsafe(ref zRef, remainder - (uint)(Vector256.Count * 2))); - vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 2)); + Vector256 vector = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 2)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 2)), + Vector256.LoadUnsafe(ref zRef, remainder - (uint)(Vector256.Count * 2))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 2)); goto case 1; } case 1: { // Store the last block, which includes any elements that wouldn't fill a full vector - end.StoreUnsafe(ref dRef, endIndex - (uint)Vector256.Count); + end.StoreUnsafe(ref dRef, endIndex - (uint)Vector256.Count); goto case 0; } @@ -6518,7 +6574,7 @@ static void Vectorized256(ref float xRef, ref float yRef, ref float zRef, ref fl } [MethodImpl(MethodImplOptions.AggressiveInlining)] - static void Vectorized256Small(ref float xRef, ref float yRef, ref float zRef, ref float dRef, nuint remainder) + static void Vectorized256Small(ref T xRef, ref T yRef, ref T zRef, ref T dRef, nuint remainder) { switch (remainder) { @@ -6528,15 +6584,15 @@ static void Vectorized256Small(ref float xRef, ref float yRef, ref float zRef, r { Debug.Assert(Vector128.IsHardwareAccelerated); - Vector128 beg = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef), - Vector128.LoadUnsafe(ref yRef), - Vector128.LoadUnsafe(ref zRef)); - Vector128 end = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count)), - Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count)), - Vector128.LoadUnsafe(ref zRef, remainder - (uint)(Vector128.Count))); + Vector128 beg = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef), + Vector128.LoadUnsafe(ref yRef), + Vector128.LoadUnsafe(ref zRef)); + Vector128 end = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)Vector128.Count), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)Vector128.Count), + Vector128.LoadUnsafe(ref zRef, remainder - (uint)Vector128.Count)); beg.StoreUnsafe(ref dRef); - end.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count)); + end.StoreUnsafe(ref dRef, remainder - (uint)Vector128.Count); break; } @@ -6545,9 +6601,9 @@ static void Vectorized256Small(ref float xRef, ref float yRef, ref float zRef, r { 
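// (Editorial aside, not part of the patch: the Vectorized*Small helpers cover tails smaller
// than one full vector by stepping down a vector width -- a tail wider than Vector128<T>.Count
// is served with two overlapping 128-bit operations, an exact 128-bit tail with one, and
// anything narrower with the explicit scalar switch cases further down.)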
Debug.Assert(Vector128.IsHardwareAccelerated); - Vector128 beg = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef), - Vector128.LoadUnsafe(ref yRef), - Vector128.LoadUnsafe(ref zRef)); + Vector128 beg = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef), + Vector128.LoadUnsafe(ref yRef), + Vector128.LoadUnsafe(ref zRef)); beg.StoreUnsafe(ref dRef); break; @@ -6582,39 +6638,39 @@ static void Vectorized256Small(ref float xRef, ref float yRef, ref float zRef, r } } - static void Vectorized512(ref float xRef, ref float yRef, ref float zRef, ref float dRef, nuint remainder) + static void Vectorized512(ref T xRef, ref T yRef, ref T zRef, ref T dRef, nuint remainder) { - ref float dRefBeg = ref dRef; + ref T dRefBeg = ref dRef; // Preload the beginning and end so that overlapping accesses don't negatively impact the data - Vector512 beg = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef), - Vector512.LoadUnsafe(ref yRef), - Vector512.LoadUnsafe(ref zRef)); - Vector512 end = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count)), - Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count)), - Vector512.LoadUnsafe(ref zRef, remainder - (uint)(Vector512.Count))); + Vector512 beg = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef), + Vector512.LoadUnsafe(ref yRef), + Vector512.LoadUnsafe(ref zRef)); + Vector512 end = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)Vector512.Count), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)Vector512.Count), + Vector512.LoadUnsafe(ref zRef, remainder - (uint)Vector512.Count)); - if (remainder > (uint)(Vector512.Count * 8)) + if (remainder > (uint)(Vector512.Count * 8)) { // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted. - fixed (float* px = &xRef) - fixed (float* py = &yRef) - fixed (float* pz = &zRef) - fixed (float* pd = &dRef) + fixed (T* px = &xRef) + fixed (T* py = &yRef) + fixed (T* pz = &zRef) + fixed (T* pd = &dRef) { - float* xPtr = px; - float* yPtr = py; - float* zPtr = pz; - float* dPtr = pd; + T* xPtr = px; + T* yPtr = py; + T* zPtr = pz; + T* dPtr = pd; // We need to ensure the underlying data can be aligned and only align // it if it can. It is possible we have an unaligned ref, in which case we // can never achieve the required SIMD alignment. - bool canAlign = ((nuint)(dPtr) % sizeof(float)) == 0; + bool canAlign = ((nuint)dPtr % (nuint)sizeof(T)) == 0; if (canAlign) { @@ -6624,134 +6680,134 @@ static void Vectorized512(ref float xRef, ref float yRef, ref float zRef, ref fl // are more expensive than unaligned loads and aligning both is significantly more // complex.
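// (Editorial aside, not part of the patch: only dPtr is aligned because unaligned SIMD loads
// are cheap on modern hardware while StoreAlignedNonTemporal requires an aligned destination.
// Aligning all four pointers at once is generally impossible anyway -- if xPtr and dPtr start,
// say, 4 bytes apart, no single element offset can make both 64-byte aligned simultaneously.)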
- nuint misalignment = ((uint)(sizeof(Vector512)) - ((nuint)(dPtr) % (uint)(sizeof(Vector512)))) / sizeof(float); + nuint misalignment = ((uint)sizeof(Vector512) - ((nuint)dPtr % (uint)sizeof(Vector512))) / (uint)sizeof(T); xPtr += misalignment; yPtr += misalignment; zPtr += misalignment; dPtr += misalignment; - Debug.Assert(((nuint)(dPtr) % (uint)(sizeof(Vector512))) == 0); + Debug.Assert(((nuint)dPtr % (uint)sizeof(Vector512)) == 0); remainder -= misalignment; } - Vector512 vector1; - Vector512 vector2; - Vector512 vector3; - Vector512 vector4; + Vector512 vector1; + Vector512 vector2; + Vector512 vector3; + Vector512 vector4; - if ((remainder > (NonTemporalByteThreshold / sizeof(float))) && canAlign) + if ((remainder > (NonTemporalByteThreshold / (nuint)sizeof(T))) && canAlign) { // This loop stores the data non-temporally, which benefits us when there // is a large amount of data involved as it avoids polluting the cache. - while (remainder >= (uint)(Vector512.Count * 8)) + while (remainder >= (uint)(Vector512.Count * 8)) { // We load, process, and store the first four vectors - vector1 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 0)), - Vector512.Load(yPtr + (uint)(Vector512.Count * 0)), - Vector512.Load(zPtr + (uint)(Vector512.Count * 0))); - vector2 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 1)), - Vector512.Load(yPtr + (uint)(Vector512.Count * 1)), - Vector512.Load(zPtr + (uint)(Vector512.Count * 1))); - vector3 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 2)), - Vector512.Load(yPtr + (uint)(Vector512.Count * 2)), - Vector512.Load(zPtr + (uint)(Vector512.Count * 2))); - vector4 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 3)), - Vector512.Load(yPtr + (uint)(Vector512.Count * 3)), - Vector512.Load(zPtr + (uint)(Vector512.Count * 3))); - - vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 0)); - vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 1)); - vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 2)); - vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 3)); + vector1 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 0)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 0)), + Vector512.Load(zPtr + (uint)(Vector512.Count * 0))); + vector2 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 1)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 1)), + Vector512.Load(zPtr + (uint)(Vector512.Count * 1))); + vector3 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 2)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 2)), + Vector512.Load(zPtr + (uint)(Vector512.Count * 2))); + vector4 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 3)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 3)), + Vector512.Load(zPtr + (uint)(Vector512.Count * 3))); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 0)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 1)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 2)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 3)); // We load, process, and store the next four vectors - vector1 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 4)), - Vector512.Load(yPtr + (uint)(Vector512.Count * 4)), - Vector512.Load(zPtr + (uint)(Vector512.Count * 4))); - vector2 = TTernaryOperator.Invoke(Vector512.Load(xPtr + 
(uint)(Vector512.Count * 5)), - Vector512.Load(yPtr + (uint)(Vector512.Count * 5)), - Vector512.Load(zPtr + (uint)(Vector512.Count * 5))); - vector3 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 6)), - Vector512.Load(yPtr + (uint)(Vector512.Count * 6)), - Vector512.Load(zPtr + (uint)(Vector512.Count * 6))); - vector4 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 7)), - Vector512.Load(yPtr + (uint)(Vector512.Count * 7)), - Vector512.Load(zPtr + (uint)(Vector512.Count * 7))); - - vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 4)); - vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 5)); - vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 6)); - vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 7)); + vector1 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 4)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 4)), + Vector512.Load(zPtr + (uint)(Vector512.Count * 4))); + vector2 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 5)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 5)), + Vector512.Load(zPtr + (uint)(Vector512.Count * 5))); + vector3 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 6)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 6)), + Vector512.Load(zPtr + (uint)(Vector512.Count * 6))); + vector4 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 7)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 7)), + Vector512.Load(zPtr + (uint)(Vector512.Count * 7))); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 4)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 5)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 6)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 7)); // We adjust the source and destination references, then update // the count of remaining elements to process. 
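// (Editorial aside, not part of the patch: one unrolled iteration consumes Vector512<T>.Count * 8
// elements. For float that is 16 * 8 == 128 elements, so each pointer below advances by
// 128 * sizeof(float) == 512 bytes per iteration, matching the 'remainder' decrement.)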
- xPtr += (uint)(Vector512.Count * 8); - yPtr += (uint)(Vector512.Count * 8); - zPtr += (uint)(Vector512.Count * 8); - dPtr += (uint)(Vector512.Count * 8); + xPtr += (uint)(Vector512.Count * 8); + yPtr += (uint)(Vector512.Count * 8); + zPtr += (uint)(Vector512.Count * 8); + dPtr += (uint)(Vector512.Count * 8); - remainder -= (uint)(Vector512.Count * 8); + remainder -= (uint)(Vector512.Count * 8); } } else { - while (remainder >= (uint)(Vector512.Count * 8)) + while (remainder >= (uint)(Vector512.Count * 8)) { // We load, process, and store the first four vectors - vector1 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 0)), - Vector512.Load(yPtr + (uint)(Vector512.Count * 0)), - Vector512.Load(zPtr + (uint)(Vector512.Count * 0))); - vector2 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 1)), - Vector512.Load(yPtr + (uint)(Vector512.Count * 1)), - Vector512.Load(zPtr + (uint)(Vector512.Count * 1))); - vector3 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 2)), - Vector512.Load(yPtr + (uint)(Vector512.Count * 2)), - Vector512.Load(zPtr + (uint)(Vector512.Count * 2))); - vector4 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 3)), - Vector512.Load(yPtr + (uint)(Vector512.Count * 3)), - Vector512.Load(zPtr + (uint)(Vector512.Count * 3))); - - vector1.Store(dPtr + (uint)(Vector512.Count * 0)); - vector2.Store(dPtr + (uint)(Vector512.Count * 1)); - vector3.Store(dPtr + (uint)(Vector512.Count * 2)); - vector4.Store(dPtr + (uint)(Vector512.Count * 3)); + vector1 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 0)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 0)), + Vector512.Load(zPtr + (uint)(Vector512.Count * 0))); + vector2 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 1)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 1)), + Vector512.Load(zPtr + (uint)(Vector512.Count * 1))); + vector3 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 2)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 2)), + Vector512.Load(zPtr + (uint)(Vector512.Count * 2))); + vector4 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 3)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 3)), + Vector512.Load(zPtr + (uint)(Vector512.Count * 3))); + + vector1.Store(dPtr + (uint)(Vector512.Count * 0)); + vector2.Store(dPtr + (uint)(Vector512.Count * 1)); + vector3.Store(dPtr + (uint)(Vector512.Count * 2)); + vector4.Store(dPtr + (uint)(Vector512.Count * 3)); // We load, process, and store the next four vectors - vector1 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 4)), - Vector512.Load(yPtr + (uint)(Vector512.Count * 4)), - Vector512.Load(zPtr + (uint)(Vector512.Count * 4))); - vector2 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 5)), - Vector512.Load(yPtr + (uint)(Vector512.Count * 5)), - Vector512.Load(zPtr + (uint)(Vector512.Count * 5))); - vector3 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 6)), - Vector512.Load(yPtr + (uint)(Vector512.Count * 6)), - Vector512.Load(zPtr + (uint)(Vector512.Count * 6))); - vector4 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 7)), - Vector512.Load(yPtr + (uint)(Vector512.Count * 7)), - Vector512.Load(zPtr + (uint)(Vector512.Count * 7))); - - vector1.Store(dPtr + (uint)(Vector512.Count * 4)); - vector2.Store(dPtr + (uint)(Vector512.Count * 5)); - vector3.Store(dPtr + 
(uint)(Vector512.Count * 6)); - vector4.Store(dPtr + (uint)(Vector512.Count * 7)); + vector1 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 4)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 4)), + Vector512.Load(zPtr + (uint)(Vector512.Count * 4))); + vector2 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 5)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 5)), + Vector512.Load(zPtr + (uint)(Vector512.Count * 5))); + vector3 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 6)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 6)), + Vector512.Load(zPtr + (uint)(Vector512.Count * 6))); + vector4 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 7)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 7)), + Vector512.Load(zPtr + (uint)(Vector512.Count * 7))); + + vector1.Store(dPtr + (uint)(Vector512.Count * 4)); + vector2.Store(dPtr + (uint)(Vector512.Count * 5)); + vector3.Store(dPtr + (uint)(Vector512.Count * 6)); + vector4.Store(dPtr + (uint)(Vector512.Count * 7)); // We adjust the source and destination references, then update // the count of remaining elements to process. - xPtr += (uint)(Vector512.Count * 8); - yPtr += (uint)(Vector512.Count * 8); - zPtr += (uint)(Vector512.Count * 8); - dPtr += (uint)(Vector512.Count * 8); + xPtr += (uint)(Vector512.Count * 8); + yPtr += (uint)(Vector512.Count * 8); + zPtr += (uint)(Vector512.Count * 8); + dPtr += (uint)(Vector512.Count * 8); - remainder -= (uint)(Vector512.Count * 8); + remainder -= (uint)(Vector512.Count * 8); } } @@ -6772,77 +6828,77 @@ static void Vectorized512(ref float xRef, ref float yRef, ref float zRef, ref fl // data before the first aligned address. nuint endIndex = remainder; - remainder = (remainder + (uint)(Vector512.Count - 1)) & (nuint)(-Vector512.Count); + remainder = (remainder + (uint)(Vector512.Count - 1)) & (nuint)(-Vector512.Count); - switch (remainder / (uint)(Vector512.Count)) + switch (remainder / (uint)Vector512.Count) { case 8: { - Vector512 vector = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 8)), - Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 8)), - Vector512.LoadUnsafe(ref zRef, remainder - (uint)(Vector512.Count * 8))); - vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 8)); + Vector512 vector = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 8)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 8)), + Vector512.LoadUnsafe(ref zRef, remainder - (uint)(Vector512.Count * 8))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 8)); goto case 7; } case 7: { - Vector512 vector = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 7)), - Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 7)), - Vector512.LoadUnsafe(ref zRef, remainder - (uint)(Vector512.Count * 7))); - vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 7)); + Vector512 vector = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 7)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 7)), + Vector512.LoadUnsafe(ref zRef, remainder - (uint)(Vector512.Count * 7))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 7)); goto case 6; } case 6: { - Vector512 vector = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - 
(uint)(Vector512.Count * 6)), - Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 6)), - Vector512.LoadUnsafe(ref zRef, remainder - (uint)(Vector512.Count * 6))); - vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 6)); + Vector512 vector = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 6)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 6)), + Vector512.LoadUnsafe(ref zRef, remainder - (uint)(Vector512.Count * 6))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 6)); goto case 5; } case 5: { - Vector512 vector = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 5)), - Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 5)), - Vector512.LoadUnsafe(ref zRef, remainder - (uint)(Vector512.Count * 5))); - vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 5)); + Vector512 vector = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 5)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 5)), + Vector512.LoadUnsafe(ref zRef, remainder - (uint)(Vector512.Count * 5))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 5)); goto case 4; } case 4: { - Vector512 vector = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 4)), - Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 4)), - Vector512.LoadUnsafe(ref zRef, remainder - (uint)(Vector512.Count * 4))); - vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 4)); + Vector512 vector = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 4)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 4)), + Vector512.LoadUnsafe(ref zRef, remainder - (uint)(Vector512.Count * 4))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 4)); goto case 3; } case 3: { - Vector512 vector = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 3)), - Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 3)), - Vector512.LoadUnsafe(ref zRef, remainder - (uint)(Vector512.Count * 3))); - vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 3)); + Vector512 vector = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 3)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 3)), + Vector512.LoadUnsafe(ref zRef, remainder - (uint)(Vector512.Count * 3))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 3)); goto case 2; } case 2: { - Vector512 vector = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 2)), - Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 2)), - Vector512.LoadUnsafe(ref zRef, remainder - (uint)(Vector512.Count * 2))); - vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 2)); + Vector512 vector = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 2)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 2)), + Vector512.LoadUnsafe(ref zRef, remainder - (uint)(Vector512.Count * 2))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 2)); goto case 1; } case 1: { // Store the last block, which includes any elements that wouldn't fill a full vector - 
end.StoreUnsafe(ref dRef, endIndex - (uint)Vector512.Count); + end.StoreUnsafe(ref dRef, endIndex - (uint)Vector512.Count); goto case 0; } @@ -6856,7 +6912,7 @@ static void Vectorized512(ref float xRef, ref float yRef, ref float zRef, ref fl } [MethodImpl(MethodImplOptions.AggressiveInlining)] - static void Vectorized512Small(ref float xRef, ref float yRef, ref float zRef, ref float dRef, nuint remainder) + static void Vectorized512Small(ref T xRef, ref T yRef, ref T zRef, ref T dRef, nuint remainder) { switch (remainder) { @@ -6870,15 +6926,15 @@ static void Vectorized512Small(ref float xRef, ref float yRef, ref float zRef, r { Debug.Assert(Vector256.IsHardwareAccelerated); - Vector256 beg = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef), - Vector256.LoadUnsafe(ref yRef), - Vector256.LoadUnsafe(ref zRef)); - Vector256 end = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count)), - Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count)), - Vector256.LoadUnsafe(ref zRef, remainder - (uint)(Vector256.Count))); + Vector256 beg = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef), + Vector256.LoadUnsafe(ref yRef), + Vector256.LoadUnsafe(ref zRef)); + Vector256 end = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)Vector256.Count), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)Vector256.Count), + Vector256.LoadUnsafe(ref zRef, remainder - (uint)Vector256.Count)); beg.StoreUnsafe(ref dRef); - end.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count)); + end.StoreUnsafe(ref dRef, remainder - (uint)Vector256.Count); break; } @@ -6887,7 +6943,7 @@ static void Vectorized512Small(ref float xRef, ref float yRef, ref float zRef, r { Debug.Assert(Vector256.IsHardwareAccelerated); - Vector256 beg = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef), + Vector256 beg = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef), Vector256.LoadUnsafe(ref yRef), Vector256.LoadUnsafe(ref zRef)); beg.StoreUnsafe(ref dRef); @@ -6901,15 +6957,15 @@ static void Vectorized512Small(ref float xRef, ref float yRef, ref float zRef, r { Debug.Assert(Vector128.IsHardwareAccelerated); - Vector128 beg = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef), + Vector128 beg = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef), Vector128.LoadUnsafe(ref yRef), Vector128.LoadUnsafe(ref zRef)); - Vector128 end = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count)), - Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count)), - Vector128.LoadUnsafe(ref zRef, remainder - (uint)(Vector128.Count))); + Vector128 end = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)Vector128.Count), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)Vector128.Count), + Vector128.LoadUnsafe(ref zRef, remainder - (uint)Vector128.Count)); beg.StoreUnsafe(ref dRef); - end.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count)); + end.StoreUnsafe(ref dRef, remainder - (uint)Vector128.Count); break; } @@ -6918,7 +6974,7 @@ static void Vectorized512Small(ref float xRef, ref float yRef, ref float zRef, r { Debug.Assert(Vector128.IsHardwareAccelerated); - Vector128 beg = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef), + Vector128 beg = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef), Vector128.LoadUnsafe(ref yRef), Vector128.LoadUnsafe(ref zRef)); beg.StoreUnsafe(ref dRef); @@ -6960,13 +7016,14 @@ static void Vectorized512Small(ref float xRef, ref float yRef, ref float 
zRef, r
        /// <summary>Performs an element-wise operation on <paramref name="x"/>, <paramref name="y"/>, and <paramref name="z"/>,
        /// and writes the results to <paramref name="destination"/>.
        /// </summary>
+       /// <typeparam name="T">The element type.</typeparam>
        /// <typeparam name="TTernaryOperator">
        /// Specifies the operation to perform on the pair-wise elements loaded from <paramref name="x"/> and <paramref name="y"/>
        /// with <paramref name="z"/>.
        /// </typeparam>
-       private static void InvokeSpanSpanScalarIntoSpan<TTernaryOperator>(
-           ReadOnlySpan<float> x, ReadOnlySpan<float> y, float z, Span<float> destination)
-           where TTernaryOperator : struct, ITernaryOperator
+       private static void InvokeSpanSpanScalarIntoSpan<T, TTernaryOperator>(
+           ReadOnlySpan<T> x, ReadOnlySpan<T> y, T z, Span<T> destination)
+           where TTernaryOperator : struct, ITernaryOperator<T>
        {
            if (x.Length != y.Length)
            {
@@ -6986,15 +7043,15 @@ private static void InvokeSpanSpanScalarIntoSpan<TTernaryOperator>(
            // in a way that allows us to have the minimum possible overhead
            // for small sizes

-           ref float xRef = ref MemoryMarshal.GetReference(x);
-           ref float yRef = ref MemoryMarshal.GetReference(y);
-           ref float dRef = ref MemoryMarshal.GetReference(destination);
+           ref T xRef = ref MemoryMarshal.GetReference(x);
+           ref T yRef = ref MemoryMarshal.GetReference(y);
+           ref T dRef = ref MemoryMarshal.GetReference(destination);

-           nuint remainder = (uint)(x.Length);
+           nuint remainder = (uint)x.Length;

-           if (Vector512.IsHardwareAccelerated)
+           if (Vector512.IsHardwareAccelerated && Vector512<T>.IsSupported && Unsafe.SizeOf<T>() >= 4)
            {
-               if (remainder >= (uint)(Vector512<float>.Count))
+               if (remainder >= (uint)Vector512<T>.Count)
                {
                    Vectorized512(ref xRef, ref yRef, z, ref dRef, remainder);
                }
@@ -7010,9 +7067,9 @@ private static void InvokeSpanSpanScalarIntoSpan<TTernaryOperator>(
                return;
            }

-           if (Vector256.IsHardwareAccelerated)
+           if (Vector256.IsHardwareAccelerated && Vector256<T>.IsSupported && Unsafe.SizeOf<T>() >= 4)
            {
-               if (remainder >= (uint)(Vector256<float>.Count))
+               if (remainder >= (uint)Vector256<T>.Count)
                {
                    Vectorized256(ref xRef, ref yRef, z, ref dRef, remainder);
                }
@@ -7028,9 +7085,9 @@ private static void InvokeSpanSpanScalarIntoSpan<TTernaryOperator>(
                return;
            }

-           if (Vector128.IsHardwareAccelerated)
+           if (Vector128.IsHardwareAccelerated && Vector128<T>.IsSupported && Unsafe.SizeOf<T>() >= 4)
            {
-               if (remainder >= (uint)(Vector128<float>.Count))
+               if (remainder >= (uint)Vector128<T>.Count)
                {
                    Vectorized128(ref xRef, ref yRef, z, ref dRef, remainder);
                }
@@ -7052,7 +7109,7 @@ private static void InvokeSpanSpanScalarIntoSpan<TTernaryOperator>(
            SoftwareFallback(ref xRef, ref yRef, z, ref dRef, remainder);

            [MethodImpl(MethodImplOptions.AggressiveInlining)]
-           static void SoftwareFallback(ref float xRef, ref float yRef, float z, ref float dRef, nuint length)
+           static void SoftwareFallback(ref T xRef, ref T yRef, T z, ref T dRef, nuint length)
            {
                for (nuint i = 0; i < length; i++)
                {
@@ -7062,39 +7119,39 @@ static void SoftwareFallback(ref float xRef, ref float yRef, float z, ref float
                }
            }

-           static void Vectorized128(ref float xRef, ref float yRef, float z, ref float dRef, nuint remainder)
+           static void Vectorized128(ref T xRef, ref T yRef, T z, ref T dRef, nuint remainder)
            {
-               ref float dRefBeg = ref dRef;
+               ref T dRefBeg = ref dRef;

                // Preload the beginning and end so that overlapping accesses don't negatively impact the data

-               Vector128<float> zVec = Vector128.Create(z);
+               Vector128<T> zVec = Vector128.Create(z);

-               Vector128<float> beg = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef),
+               Vector128<T> beg = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef),
                                                            Vector128.LoadUnsafe(ref yRef),
                                                            zVec);
-               Vector128<float> end = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128<float>.Count)),
-                                                              Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128<float>.Count)),
+               Vector128<T> end = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)Vector128<T>.Count),
+                                                          Vector128.LoadUnsafe(ref yRef, remainder - (uint)Vector128<T>.Count),
                                                            zVec);

-               if (remainder > (uint)(Vector128<float>.Count * 8))
+               if (remainder > (uint)(Vector128<T>.Count * 8))
                {
                    // Pinning is cheap and will be short-lived for small inputs and unlikely to be impactful
                    // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted.

-                   fixed (float* px = &xRef)
-                   fixed (float* py = &yRef)
-                   fixed (float* pd = &dRef)
+                   fixed (T* px = &xRef)
+                   fixed (T* py = &yRef)
+                   fixed (T* pd = &dRef)
                    {
-                       float* xPtr = px;
-                       float* yPtr = py;
-                       float* dPtr = pd;
+                       T* xPtr = px;
+                       T* yPtr = py;
+                       T* dPtr = pd;

                        // We need to ensure the underlying data can be aligned, and we only align
                        // it if it can. It is possible we have an unaligned ref, in which case we
                        // can never achieve the required SIMD alignment.
-                       bool canAlign = ((nuint)(dPtr) % sizeof(float)) == 0;
+                       bool canAlign = ((nuint)dPtr % (nuint)sizeof(T)) == 0;

                        if (canAlign)
                        {
@@ -7104,131 +7161,131 @@ static void Vectorized128(ref float xRef, ref float yRef, float z, ref float dRe
                            // are more expensive than unaligned loads and aligning both is significantly more
                            // complex.

-                           nuint misalignment = ((uint)(sizeof(Vector128<float>)) - ((nuint)(dPtr) % (uint)(sizeof(Vector128<float>)))) / sizeof(float);
+                           nuint misalignment = ((uint)sizeof(Vector128<T>) - ((nuint)dPtr % (uint)sizeof(Vector128<T>))) / (uint)sizeof(T);

                            xPtr += misalignment;
                            yPtr += misalignment;
                            dPtr += misalignment;

-                           Debug.Assert(((nuint)(dPtr) % (uint)(sizeof(Vector128<float>))) == 0);
+                           Debug.Assert(((nuint)dPtr % (uint)sizeof(Vector128<T>)) == 0);

                            remainder -= misalignment;
                        }

-                       Vector128<float> vector1;
-                       Vector128<float> vector2;
-                       Vector128<float> vector3;
-                       Vector128<float> vector4;
+                       Vector128<T> vector1;
+                       Vector128<T> vector2;
+                       Vector128<T> vector3;
+                       Vector128<T> vector4;

-                       if ((remainder > (NonTemporalByteThreshold / sizeof(float))) && canAlign)
+                       if ((remainder > (NonTemporalByteThreshold / (nuint)sizeof(T))) && canAlign)
                        {
                            // This loop stores the data non-temporally, which benefits us when there
                            // is a large amount of data involved as it avoids polluting the cache.
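// Aside, for orientation: the loop below relies on StoreAlignedNonTemporal, which
// writes around the cache hierarchy. A minimal, self-contained sketch of that
// pattern, assuming float data; "Copy" and "ThresholdBytes" are illustrative
// names, while Vector128.Load and StoreAlignedNonTemporal are the real APIs.
using System.Runtime.Intrinsics;

internal static unsafe class NonTemporalStoreSketch
{
    // Stand-in for the internal NonTemporalByteThreshold constant (hypothetical value).
    private const nuint ThresholdBytes = 256 * 1024;

    public static void Copy(float* src, float* dst, nuint count)
    {
        nuint i = 0;
        bool aligned = ((nuint)dst % (uint)sizeof(Vector128<float>)) == 0;

        // Non-temporal stores require alignment and only pay off when the data
        // is too large to remain in the cache anyway.
        if (aligned && count * sizeof(float) > ThresholdBytes)
        {
            for (; i + (uint)Vector128<float>.Count <= count; i += (uint)Vector128<float>.Count)
            {
                Vector128.Load(src + i).StoreAlignedNonTemporal(dst + i);
            }
        }

        for (; i < count; i++) // scalar tail, and the small/unaligned path
        {
            dst[i] = src[i];
        }
    }
}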
- while (remainder >= (uint)(Vector128.Count * 8)) + while (remainder >= (uint)(Vector128.Count * 8)) { // We load, process, and store the first four vectors - vector1 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 0)), - Vector128.Load(yPtr + (uint)(Vector128.Count * 0)), + vector1 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 0)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 0)), zVec); - vector2 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 1)), - Vector128.Load(yPtr + (uint)(Vector128.Count * 1)), + vector2 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 1)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 1)), zVec); - vector3 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 2)), - Vector128.Load(yPtr + (uint)(Vector128.Count * 2)), + vector3 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 2)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 2)), zVec); - vector4 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 3)), - Vector128.Load(yPtr + (uint)(Vector128.Count * 3)), + vector4 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 3)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 3)), zVec); - vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 0)); - vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 1)); - vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 2)); - vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 3)); + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 0)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 1)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 2)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 3)); // We load, process, and store the next four vectors - vector1 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 4)), - Vector128.Load(yPtr + (uint)(Vector128.Count * 4)), + vector1 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 4)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 4)), zVec); - vector2 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 5)), - Vector128.Load(yPtr + (uint)(Vector128.Count * 5)), + vector2 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 5)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 5)), zVec); - vector3 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 6)), - Vector128.Load(yPtr + (uint)(Vector128.Count * 6)), + vector3 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 6)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 6)), zVec); - vector4 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 7)), - Vector128.Load(yPtr + (uint)(Vector128.Count * 7)), + vector4 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 7)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 7)), zVec); - vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 4)); - vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 5)); - vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 6)); - vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 7)); + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 4)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 5)); + 
vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 6)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 7)); // We adjust the source and destination references, then update // the count of remaining elements to process. - xPtr += (uint)(Vector128.Count * 8); - yPtr += (uint)(Vector128.Count * 8); - dPtr += (uint)(Vector128.Count * 8); + xPtr += (uint)(Vector128.Count * 8); + yPtr += (uint)(Vector128.Count * 8); + dPtr += (uint)(Vector128.Count * 8); - remainder -= (uint)(Vector128.Count * 8); + remainder -= (uint)(Vector128.Count * 8); } } else { - while (remainder >= (uint)(Vector128.Count * 8)) + while (remainder >= (uint)(Vector128.Count * 8)) { // We load, process, and store the first four vectors - vector1 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 0)), - Vector128.Load(yPtr + (uint)(Vector128.Count * 0)), + vector1 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 0)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 0)), zVec); - vector2 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 1)), - Vector128.Load(yPtr + (uint)(Vector128.Count * 1)), + vector2 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 1)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 1)), zVec); - vector3 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 2)), - Vector128.Load(yPtr + (uint)(Vector128.Count * 2)), + vector3 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 2)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 2)), zVec); - vector4 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 3)), - Vector128.Load(yPtr + (uint)(Vector128.Count * 3)), + vector4 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 3)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 3)), zVec); - vector1.Store(dPtr + (uint)(Vector128.Count * 0)); - vector2.Store(dPtr + (uint)(Vector128.Count * 1)); - vector3.Store(dPtr + (uint)(Vector128.Count * 2)); - vector4.Store(dPtr + (uint)(Vector128.Count * 3)); + vector1.Store(dPtr + (uint)(Vector128.Count * 0)); + vector2.Store(dPtr + (uint)(Vector128.Count * 1)); + vector3.Store(dPtr + (uint)(Vector128.Count * 2)); + vector4.Store(dPtr + (uint)(Vector128.Count * 3)); // We load, process, and store the next four vectors - vector1 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 4)), - Vector128.Load(yPtr + (uint)(Vector128.Count * 4)), + vector1 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 4)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 4)), zVec); - vector2 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 5)), - Vector128.Load(yPtr + (uint)(Vector128.Count * 5)), + vector2 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 5)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 5)), zVec); - vector3 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 6)), - Vector128.Load(yPtr + (uint)(Vector128.Count * 6)), + vector3 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 6)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 6)), zVec); - vector4 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 7)), - Vector128.Load(yPtr + (uint)(Vector128.Count * 7)), + vector4 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 7)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 7)), zVec); - 
vector1.Store(dPtr + (uint)(Vector128.Count * 4)); - vector2.Store(dPtr + (uint)(Vector128.Count * 5)); - vector3.Store(dPtr + (uint)(Vector128.Count * 6)); - vector4.Store(dPtr + (uint)(Vector128.Count * 7)); + vector1.Store(dPtr + (uint)(Vector128.Count * 4)); + vector2.Store(dPtr + (uint)(Vector128.Count * 5)); + vector3.Store(dPtr + (uint)(Vector128.Count * 6)); + vector4.Store(dPtr + (uint)(Vector128.Count * 7)); // We adjust the source and destination references, then update // the count of remaining elements to process. - xPtr += (uint)(Vector128.Count * 8); - yPtr += (uint)(Vector128.Count * 8); - dPtr += (uint)(Vector128.Count * 8); + xPtr += (uint)(Vector128.Count * 8); + yPtr += (uint)(Vector128.Count * 8); + dPtr += (uint)(Vector128.Count * 8); - remainder -= (uint)(Vector128.Count * 8); + remainder -= (uint)(Vector128.Count * 8); } } @@ -7248,77 +7305,77 @@ static void Vectorized128(ref float xRef, ref float yRef, float z, ref float dRe // data before the first aligned address. nuint endIndex = remainder; - remainder = (remainder + (uint)(Vector128.Count - 1)) & (nuint)(-Vector128.Count); + remainder = (remainder + (uint)(Vector128.Count - 1)) & (nuint)(-Vector128.Count); - switch (remainder / (uint)(Vector128.Count)) + switch (remainder / (uint)Vector128.Count) { case 8: { - Vector128 vector = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 8)), - Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 8)), + Vector128 vector = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 8)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 8)), zVec); - vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 8)); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 8)); goto case 7; } case 7: { - Vector128 vector = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 7)), - Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 7)), + Vector128 vector = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 7)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 7)), zVec); - vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 7)); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 7)); goto case 6; } case 6: { - Vector128 vector = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 6)), - Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 6)), + Vector128 vector = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 6)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 6)), zVec); - vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 6)); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 6)); goto case 5; } case 5: { - Vector128 vector = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 5)), - Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 5)), + Vector128 vector = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 5)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 5)), zVec); - vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 5)); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 5)); goto case 4; } 
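// Aside: this case ladder is a Duff's-device-style dispatch. A reduced, scalar
// sketch of the same control flow, with Process standing in for one vector's
// load/compute/store and BlockSize for Vector128<T>.Count (both illustrative):
internal static class RemainderSwitchSketch
{
    // Not a const: mirrors Vector128<T>.Count, and keeps (nuint)(-BlockSize) a
    // runtime (wrapping) conversion rather than an out-of-range constant.
    private static readonly int BlockSize = 4;

    private static void Process(double[] src, double[] dst, nuint offset)
    {
        // Recomputes purely from src, so overlapping blocks are harmless.
        for (nuint i = 0; i < (nuint)BlockSize; i++) { dst[offset + i] = src[offset + i] * 2; }
    }

    // Assumes BlockSize <= src.Length <= 3 * BlockSize (the real code bounds its
    // switch at 8 blocks after the unrolled main loop has run).
    public static void Handle(double[] src, double[] dst)
    {
        nuint endIndex = (uint)src.Length;
        // Round up to whole blocks so the switch sees a block count.
        nuint remainder = (endIndex + (uint)(BlockSize - 1)) & (nuint)(-BlockSize);

        switch (remainder / (uint)BlockSize)
        {
            case 3: Process(src, dst, remainder - (uint)(BlockSize * 3)); goto case 2;
            case 2: Process(src, dst, remainder - (uint)(BlockSize * 2)); goto case 1;
            case 1:
                // The last block is anchored to the true end, so a partial final
                // block overlaps the previous store instead of running off the end.
                Process(src, dst, endIndex - (uint)BlockSize);
                goto case 0;
            case 0:
                break;
        }
    }
}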
case 4: { - Vector128 vector = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 4)), - Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 4)), + Vector128 vector = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 4)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 4)), zVec); - vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 4)); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 4)); goto case 3; } case 3: { - Vector128 vector = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 3)), - Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 3)), + Vector128 vector = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 3)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 3)), zVec); - vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 3)); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 3)); goto case 2; } case 2: { - Vector128 vector = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 2)), - Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 2)), + Vector128 vector = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 2)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 2)), zVec); - vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 2)); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 2)); goto case 1; } case 1: { // Store the last block, which includes any elements that wouldn't fill a full vector - end.StoreUnsafe(ref dRef, endIndex - (uint)Vector128.Count); + end.StoreUnsafe(ref dRef, endIndex - (uint)Vector128.Count); goto case 0; } @@ -7332,7 +7389,7 @@ static void Vectorized128(ref float xRef, ref float yRef, float z, ref float dRe } [MethodImpl(MethodImplOptions.AggressiveInlining)] - static void Vectorized128Small(ref float xRef, ref float yRef, float z, ref float dRef, nuint remainder) + static void Vectorized128Small(ref T xRef, ref T yRef, T z, ref T dRef, nuint remainder) { switch (remainder) { @@ -7365,39 +7422,39 @@ static void Vectorized128Small(ref float xRef, ref float yRef, float z, ref floa } } - static void Vectorized256(ref float xRef, ref float yRef, float z, ref float dRef, nuint remainder) + static void Vectorized256(ref T xRef, ref T yRef, T z, ref T dRef, nuint remainder) { - ref float dRefBeg = ref dRef; + ref T dRefBeg = ref dRef; // Preload the beginning and end so that overlapping accesses don't negatively impact the data - Vector256 zVec = Vector256.Create(z); + Vector256 zVec = Vector256.Create(z); - Vector256 beg = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef), + Vector256 beg = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef), Vector256.LoadUnsafe(ref yRef), zVec); - Vector256 end = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count)), - Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count)), + Vector256 end = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)Vector256.Count), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)Vector256.Count), zVec); - if (remainder > (uint)(Vector256.Count * 8)) + if (remainder > (uint)(Vector256.Count * 8)) { // Pinning is cheap and will be short lived for small inputs and 
unlikely to be impactful
                    // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted.

-                   fixed (float* px = &xRef)
-                   fixed (float* py = &yRef)
-                   fixed (float* pd = &dRef)
+                   fixed (T* px = &xRef)
+                   fixed (T* py = &yRef)
+                   fixed (T* pd = &dRef)
                    {
-                       float* xPtr = px;
-                       float* yPtr = py;
-                       float* dPtr = pd;
+                       T* xPtr = px;
+                       T* yPtr = py;
+                       T* dPtr = pd;

                        // We need to ensure the underlying data can be aligned, and we only align
                        // it if it can. It is possible we have an unaligned ref, in which case we
                        // can never achieve the required SIMD alignment.
-                       bool canAlign = ((nuint)(dPtr) % sizeof(float)) == 0;
+                       bool canAlign = ((nuint)dPtr % (nuint)sizeof(T)) == 0;

                        if (canAlign)
                        {
@@ -7407,131 +7464,131 @@ static void Vectorized256(ref float xRef, ref float yRef, float z, ref float dRe
                            // are more expensive than unaligned loads and aligning both is significantly more
                            // complex.

-                           nuint misalignment = ((uint)(sizeof(Vector256<float>)) - ((nuint)(dPtr) % (uint)(sizeof(Vector256<float>)))) / sizeof(float);
+                           nuint misalignment = ((uint)sizeof(Vector256<T>) - ((nuint)dPtr % (uint)sizeof(Vector256<T>))) / (uint)sizeof(T);

                            xPtr += misalignment;
                            yPtr += misalignment;
                            dPtr += misalignment;

-                           Debug.Assert(((nuint)(dPtr) % (uint)(sizeof(Vector256<float>))) == 0);
+                           Debug.Assert(((nuint)dPtr % (uint)sizeof(Vector256<T>)) == 0);

                            remainder -= misalignment;
                        }

-                       Vector256<float> vector1;
-                       Vector256<float> vector2;
-                       Vector256<float> vector3;
-                       Vector256<float> vector4;
+                       Vector256<T> vector1;
+                       Vector256<T> vector2;
+                       Vector256<T> vector3;
+                       Vector256<T> vector4;

-                       if ((remainder > (NonTemporalByteThreshold / sizeof(float))) && canAlign)
+                       if ((remainder > (NonTemporalByteThreshold / (nuint)sizeof(T))) && canAlign)
                        {
                            // This loop stores the data non-temporally, which benefits us when there
                            // is a large amount of data involved as it avoids polluting the cache.
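// Aside: the misalignment computation above generalizes the float-only version by
// dividing by sizeof(T) rather than sizeof(float). The same arithmetic, as a
// standalone sketch (ElementsUntilAligned is a hypothetical name):
using System.Runtime.Intrinsics;

internal static unsafe class AlignmentSketch
{
    // Number of T-sized elements to advance so that p lands on a Vector256 (32-byte)
    // boundary; e.g. for T = float and (nuint)p % 32 == 8, this is (32 - 8) / 4 = 6.
    // If p is already aligned it returns a full vector's worth of elements, which
    // the caller tolerates because the first vector is preloaded and stored separately.
    public static nuint ElementsUntilAligned<T>(T* p) where T : unmanaged
        => ((uint)sizeof(Vector256<T>) - ((nuint)p % (uint)sizeof(Vector256<T>))) / (uint)sizeof(T);
}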
- while (remainder >= (uint)(Vector256.Count * 8)) + while (remainder >= (uint)(Vector256.Count * 8)) { // We load, process, and store the first four vectors - vector1 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 0)), - Vector256.Load(yPtr + (uint)(Vector256.Count * 0)), + vector1 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 0)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 0)), zVec); - vector2 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 1)), - Vector256.Load(yPtr + (uint)(Vector256.Count * 1)), + vector2 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 1)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 1)), zVec); - vector3 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 2)), - Vector256.Load(yPtr + (uint)(Vector256.Count * 2)), + vector3 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 2)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 2)), zVec); - vector4 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 3)), - Vector256.Load(yPtr + (uint)(Vector256.Count * 3)), + vector4 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 3)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 3)), zVec); - vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 0)); - vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 1)); - vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 2)); - vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 3)); + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 0)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 1)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 2)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 3)); // We load, process, and store the next four vectors - vector1 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 4)), - Vector256.Load(yPtr + (uint)(Vector256.Count * 4)), + vector1 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 4)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 4)), zVec); - vector2 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 5)), - Vector256.Load(yPtr + (uint)(Vector256.Count * 5)), + vector2 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 5)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 5)), zVec); - vector3 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 6)), - Vector256.Load(yPtr + (uint)(Vector256.Count * 6)), + vector3 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 6)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 6)), zVec); - vector4 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 7)), - Vector256.Load(yPtr + (uint)(Vector256.Count * 7)), + vector4 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 7)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 7)), zVec); - vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 4)); - vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 5)); - vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 6)); - vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 7)); + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 4)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 5)); + 
vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 6)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 7)); // We adjust the source and destination references, then update // the count of remaining elements to process. - xPtr += (uint)(Vector256.Count * 8); - yPtr += (uint)(Vector256.Count * 8); - dPtr += (uint)(Vector256.Count * 8); + xPtr += (uint)(Vector256.Count * 8); + yPtr += (uint)(Vector256.Count * 8); + dPtr += (uint)(Vector256.Count * 8); - remainder -= (uint)(Vector256.Count * 8); + remainder -= (uint)(Vector256.Count * 8); } } else { - while (remainder >= (uint)(Vector256.Count * 8)) + while (remainder >= (uint)(Vector256.Count * 8)) { // We load, process, and store the first four vectors - vector1 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 0)), - Vector256.Load(yPtr + (uint)(Vector256.Count * 0)), + vector1 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 0)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 0)), zVec); - vector2 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 1)), - Vector256.Load(yPtr + (uint)(Vector256.Count * 1)), + vector2 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 1)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 1)), zVec); - vector3 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 2)), - Vector256.Load(yPtr + (uint)(Vector256.Count * 2)), + vector3 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 2)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 2)), zVec); - vector4 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 3)), - Vector256.Load(yPtr + (uint)(Vector256.Count * 3)), + vector4 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 3)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 3)), zVec); - vector1.Store(dPtr + (uint)(Vector256.Count * 0)); - vector2.Store(dPtr + (uint)(Vector256.Count * 1)); - vector3.Store(dPtr + (uint)(Vector256.Count * 2)); - vector4.Store(dPtr + (uint)(Vector256.Count * 3)); + vector1.Store(dPtr + (uint)(Vector256.Count * 0)); + vector2.Store(dPtr + (uint)(Vector256.Count * 1)); + vector3.Store(dPtr + (uint)(Vector256.Count * 2)); + vector4.Store(dPtr + (uint)(Vector256.Count * 3)); // We load, process, and store the next four vectors - vector1 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 4)), - Vector256.Load(yPtr + (uint)(Vector256.Count * 4)), + vector1 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 4)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 4)), zVec); - vector2 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 5)), - Vector256.Load(yPtr + (uint)(Vector256.Count * 5)), + vector2 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 5)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 5)), zVec); - vector3 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 6)), - Vector256.Load(yPtr + (uint)(Vector256.Count * 6)), + vector3 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 6)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 6)), zVec); - vector4 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 7)), - Vector256.Load(yPtr + (uint)(Vector256.Count * 7)), + vector4 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 7)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 7)), zVec); - 
vector1.Store(dPtr + (uint)(Vector256.Count * 4)); - vector2.Store(dPtr + (uint)(Vector256.Count * 5)); - vector3.Store(dPtr + (uint)(Vector256.Count * 6)); - vector4.Store(dPtr + (uint)(Vector256.Count * 7)); + vector1.Store(dPtr + (uint)(Vector256.Count * 4)); + vector2.Store(dPtr + (uint)(Vector256.Count * 5)); + vector3.Store(dPtr + (uint)(Vector256.Count * 6)); + vector4.Store(dPtr + (uint)(Vector256.Count * 7)); // We adjust the source and destination references, then update // the count of remaining elements to process. - xPtr += (uint)(Vector256.Count * 8); - yPtr += (uint)(Vector256.Count * 8); - dPtr += (uint)(Vector256.Count * 8); + xPtr += (uint)(Vector256.Count * 8); + yPtr += (uint)(Vector256.Count * 8); + dPtr += (uint)(Vector256.Count * 8); - remainder -= (uint)(Vector256.Count * 8); + remainder -= (uint)(Vector256.Count * 8); } } @@ -7551,77 +7608,77 @@ static void Vectorized256(ref float xRef, ref float yRef, float z, ref float dRe // data before the first aligned address. nuint endIndex = remainder; - remainder = (remainder + (uint)(Vector256.Count - 1)) & (nuint)(-Vector256.Count); + remainder = (remainder + (uint)(Vector256.Count - 1)) & (nuint)(-Vector256.Count); - switch (remainder / (uint)(Vector256.Count)) + switch (remainder / (uint)Vector256.Count) { case 8: { - Vector256 vector = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 8)), - Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 8)), + Vector256 vector = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 8)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 8)), zVec); - vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 8)); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 8)); goto case 7; } case 7: { - Vector256 vector = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 7)), - Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 7)), + Vector256 vector = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 7)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 7)), zVec); - vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 7)); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 7)); goto case 6; } case 6: { - Vector256 vector = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 6)), - Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 6)), + Vector256 vector = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 6)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 6)), zVec); - vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 6)); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 6)); goto case 5; } case 5: { - Vector256 vector = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 5)), - Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 5)), + Vector256 vector = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 5)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 5)), zVec); - vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 5)); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 5)); goto case 4; } 
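// Aside: the "case 1" block further down stores a vector that was computed before
// the switch ran, which is what makes the overlapping final store safe. A minimal
// sketch of that preload-the-ends idea with a concrete negate kernel; Negate is an
// illustrative name, and d is assumed to be at least as long as x:
using System.Runtime.InteropServices;
using System.Runtime.Intrinsics;

internal static class OverlapStoreSketch
{
    // Requires x.Length >= Vector128<float>.Count, as the vectorized paths here do.
    public static void Negate(ReadOnlySpan<float> x, Span<float> d)
    {
        ref float xRef = ref MemoryMarshal.GetReference(x);
        ref float dRef = ref MemoryMarshal.GetReference(d);
        nuint last = (uint)x.Length - (uint)Vector128<float>.Count;

        // Read the first and last blocks before any store, so the results stay
        // correct even when x and d are the same buffer.
        Vector128<float> beg = -Vector128.LoadUnsafe(ref xRef);
        Vector128<float> end = -Vector128.LoadUnsafe(ref xRef, last);

        for (nuint i = (uint)Vector128<float>.Count; i < last; i += (uint)Vector128<float>.Count)
        {
            (-Vector128.LoadUnsafe(ref xRef, i)).StoreUnsafe(ref dRef, i);
        }

        beg.StoreUnsafe(ref dRef);
        end.StoreUnsafe(ref dRef, last); // may overlap the loop's final store; harmless
    }
}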
case 4: { - Vector256 vector = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 4)), - Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 4)), + Vector256 vector = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 4)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 4)), zVec); - vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 4)); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 4)); goto case 3; } case 3: { - Vector256 vector = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 3)), - Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 3)), + Vector256 vector = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 3)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 3)), zVec); - vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 3)); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 3)); goto case 2; } case 2: { - Vector256 vector = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 2)), - Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 2)), + Vector256 vector = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 2)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 2)), zVec); - vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 2)); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 2)); goto case 1; } case 1: { // Store the last block, which includes any elements that wouldn't fill a full vector - end.StoreUnsafe(ref dRef, endIndex - (uint)Vector256.Count); + end.StoreUnsafe(ref dRef, endIndex - (uint)Vector256.Count); goto case 0; } @@ -7635,7 +7692,7 @@ static void Vectorized256(ref float xRef, ref float yRef, float z, ref float dRe } [MethodImpl(MethodImplOptions.AggressiveInlining)] - static void Vectorized256Small(ref float xRef, ref float yRef, float z, ref float dRef, nuint remainder) + static void Vectorized256Small(ref T xRef, ref T yRef, T z, ref T dRef, nuint remainder) { switch (remainder) { @@ -7645,17 +7702,17 @@ static void Vectorized256Small(ref float xRef, ref float yRef, float z, ref floa { Debug.Assert(Vector128.IsHardwareAccelerated); - Vector128 zVec = Vector128.Create(z); + Vector128 zVec = Vector128.Create(z); - Vector128 beg = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef), - Vector128.LoadUnsafe(ref yRef), - zVec); - Vector128 end = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count)), - Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count)), - zVec); + Vector128 beg = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef), + Vector128.LoadUnsafe(ref yRef), + zVec); + Vector128 end = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)Vector128.Count), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)Vector128.Count), + zVec); beg.StoreUnsafe(ref dRef); - end.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count)); + end.StoreUnsafe(ref dRef, remainder - (uint)Vector128.Count); break; } @@ -7664,9 +7721,9 @@ static void Vectorized256Small(ref float xRef, ref float yRef, float z, ref floa { Debug.Assert(Vector128.IsHardwareAccelerated); - Vector128 beg = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef), - 
Vector128.LoadUnsafe(ref yRef),
-                                                                  Vector128.Create(z));
+                   Vector128<T> beg = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef),
+                                                              Vector128.LoadUnsafe(ref yRef),
+                                                              Vector128.Create(z));
                    beg.StoreUnsafe(ref dRef);

                    break;
@@ -7701,39 +7758,39 @@ static void Vectorized256Small(ref float xRef, ref float yRef, float z, ref floa
                }
            }

-           static void Vectorized512(ref float xRef, ref float yRef, float z, ref float dRef, nuint remainder)
+           static void Vectorized512(ref T xRef, ref T yRef, T z, ref T dRef, nuint remainder)
            {
-               ref float dRefBeg = ref dRef;
+               ref T dRefBeg = ref dRef;

                // Preload the beginning and end so that overlapping accesses don't negatively impact the data

-               Vector512<float> zVec = Vector512.Create(z);
+               Vector512<T> zVec = Vector512.Create(z);

-               Vector512<float> beg = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef),
-                                                              Vector512.LoadUnsafe(ref yRef),
-                                                              zVec);
-               Vector512<float> end = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512<float>.Count)),
-                                                              Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512<float>.Count)),
-                                                              zVec);
+               Vector512<T> beg = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef),
+                                                          Vector512.LoadUnsafe(ref yRef),
+                                                          zVec);
+               Vector512<T> end = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)Vector512<T>.Count),
+                                                          Vector512.LoadUnsafe(ref yRef, remainder - (uint)Vector512<T>.Count),
+                                                          zVec);

-               if (remainder > (uint)(Vector512<float>.Count * 8))
+               if (remainder > (uint)(Vector512<T>.Count * 8))
                {
                    // Pinning is cheap and will be short-lived for small inputs and unlikely to be impactful
                    // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted.

-                   fixed (float* px = &xRef)
-                   fixed (float* py = &yRef)
-                   fixed (float* pd = &dRef)
+                   fixed (T* px = &xRef)
+                   fixed (T* py = &yRef)
+                   fixed (T* pd = &dRef)
                    {
-                       float* xPtr = px;
-                       float* yPtr = py;
-                       float* dPtr = pd;
+                       T* xPtr = px;
+                       T* yPtr = py;
+                       T* dPtr = pd;

                        // We need to ensure the underlying data can be aligned, and we only align
                        // it if it can. It is possible we have an unaligned ref, in which case we
                        // can never achieve the required SIMD alignment.
-                       bool canAlign = ((nuint)(dPtr) % sizeof(float)) == 0;
+                       bool canAlign = ((nuint)dPtr % (nuint)sizeof(T)) == 0;

                        if (canAlign)
                        {
@@ -7743,131 +7800,131 @@ static void Vectorized512(ref float xRef, ref float yRef, float z, ref float dRe
                            // are more expensive than unaligned loads and aligning both is significantly more
                            // complex.

-                           nuint misalignment = ((uint)(sizeof(Vector512<float>)) - ((nuint)(dPtr) % (uint)(sizeof(Vector512<float>)))) / sizeof(float);
+                           nuint misalignment = ((uint)sizeof(Vector512<T>) - ((nuint)dPtr % (uint)sizeof(Vector512<T>))) / (uint)sizeof(T);

                            xPtr += misalignment;
                            yPtr += misalignment;
                            dPtr += misalignment;

-                           Debug.Assert(((nuint)(dPtr) % (uint)(sizeof(Vector512<float>))) == 0);
+                           Debug.Assert(((nuint)dPtr % (uint)sizeof(Vector512<T>)) == 0);

                            remainder -= misalignment;
                        }

-                       Vector512<float> vector1;
-                       Vector512<float> vector2;
-                       Vector512<float> vector3;
-                       Vector512<float> vector4;
+                       Vector512<T> vector1;
+                       Vector512<T> vector2;
+                       Vector512<T> vector3;
+                       Vector512<T> vector4;

-                       if ((remainder > (NonTemporalByteThreshold / sizeof(float))) && canAlign)
+                       if ((remainder > (NonTemporalByteThreshold / (nuint)sizeof(T))) && canAlign)
                        {
                            // This loop stores the data non-temporally, which benefits us when there
                            // is a large amount of data involved as it avoids polluting the cache.
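// Aside: the main loop below handles eight vectors per iteration. A reduced,
// two-wide sketch of its shape for a concrete multiply-add kernel; MulAdd is an
// illustrative name, and the tail handling (the switch further down) is omitted:
using System.Runtime.Intrinsics;

internal static unsafe class UnrolledLoopSketch
{
    // Computes dst[i] = a[i] * b[i] + z for the leading whole pairs of vectors.
    // Returns the number of elements processed.
    public static nuint MulAdd(float* a, float* b, float z, float* dst, nuint count)
    {
        Vector512<float> zVec = Vector512.Create(z);
        nuint n = (uint)Vector512<float>.Count;
        nuint done = 0;

        while (count - done >= n * 2)
        {
            // Independent load/compute chains (the real loop uses eight) let the
            // CPU overlap memory and arithmetic latencies across iterations.
            Vector512<float> v0 = Vector512.Load(a + n * 0) * Vector512.Load(b + n * 0) + zVec;
            Vector512<float> v1 = Vector512.Load(a + n * 1) * Vector512.Load(b + n * 1) + zVec;

            v0.Store(dst + n * 0);
            v1.Store(dst + n * 1);

            // Bump the pointers and the processed count, as the loop below does.
            a += n * 2; b += n * 2; dst += n * 2;
            done += n * 2;
        }

        return done;
    }
}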
- while (remainder >= (uint)(Vector512.Count * 8)) + while (remainder >= (uint)(Vector512.Count * 8)) { // We load, process, and store the first four vectors - vector1 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 0)), - Vector512.Load(yPtr + (uint)(Vector512.Count * 0)), + vector1 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 0)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 0)), zVec); - vector2 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 1)), - Vector512.Load(yPtr + (uint)(Vector512.Count * 1)), + vector2 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 1)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 1)), zVec); - vector3 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 2)), - Vector512.Load(yPtr + (uint)(Vector512.Count * 2)), + vector3 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 2)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 2)), zVec); - vector4 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 3)), - Vector512.Load(yPtr + (uint)(Vector512.Count * 3)), + vector4 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 3)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 3)), zVec); - vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 0)); - vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 1)); - vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 2)); - vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 3)); + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 0)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 1)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 2)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 3)); // We load, process, and store the next four vectors - vector1 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 4)), - Vector512.Load(yPtr + (uint)(Vector512.Count * 4)), + vector1 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 4)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 4)), zVec); - vector2 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 5)), - Vector512.Load(yPtr + (uint)(Vector512.Count * 5)), + vector2 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 5)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 5)), zVec); - vector3 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 6)), - Vector512.Load(yPtr + (uint)(Vector512.Count * 6)), + vector3 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 6)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 6)), zVec); - vector4 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 7)), - Vector512.Load(yPtr + (uint)(Vector512.Count * 7)), + vector4 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 7)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 7)), zVec); - vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 4)); - vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 5)); - vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 6)); - vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 7)); + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 4)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 5)); + 
vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 6)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 7)); // We adjust the source and destination references, then update // the count of remaining elements to process. - xPtr += (uint)(Vector512.Count * 8); - yPtr += (uint)(Vector512.Count * 8); - dPtr += (uint)(Vector512.Count * 8); + xPtr += (uint)(Vector512.Count * 8); + yPtr += (uint)(Vector512.Count * 8); + dPtr += (uint)(Vector512.Count * 8); - remainder -= (uint)(Vector512.Count * 8); + remainder -= (uint)(Vector512.Count * 8); } } else { - while (remainder >= (uint)(Vector512.Count * 8)) + while (remainder >= (uint)(Vector512.Count * 8)) { // We load, process, and store the first four vectors - vector1 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 0)), - Vector512.Load(yPtr + (uint)(Vector512.Count * 0)), + vector1 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 0)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 0)), zVec); - vector2 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 1)), - Vector512.Load(yPtr + (uint)(Vector512.Count * 1)), + vector2 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 1)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 1)), zVec); - vector3 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 2)), - Vector512.Load(yPtr + (uint)(Vector512.Count * 2)), + vector3 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 2)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 2)), zVec); - vector4 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 3)), - Vector512.Load(yPtr + (uint)(Vector512.Count * 3)), + vector4 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 3)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 3)), zVec); - vector1.Store(dPtr + (uint)(Vector512.Count * 0)); - vector2.Store(dPtr + (uint)(Vector512.Count * 1)); - vector3.Store(dPtr + (uint)(Vector512.Count * 2)); - vector4.Store(dPtr + (uint)(Vector512.Count * 3)); + vector1.Store(dPtr + (uint)(Vector512.Count * 0)); + vector2.Store(dPtr + (uint)(Vector512.Count * 1)); + vector3.Store(dPtr + (uint)(Vector512.Count * 2)); + vector4.Store(dPtr + (uint)(Vector512.Count * 3)); // We load, process, and store the next four vectors - vector1 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 4)), - Vector512.Load(yPtr + (uint)(Vector512.Count * 4)), + vector1 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 4)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 4)), zVec); - vector2 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 5)), - Vector512.Load(yPtr + (uint)(Vector512.Count * 5)), + vector2 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 5)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 5)), zVec); - vector3 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 6)), - Vector512.Load(yPtr + (uint)(Vector512.Count * 6)), + vector3 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 6)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 6)), zVec); - vector4 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 7)), - Vector512.Load(yPtr + (uint)(Vector512.Count * 7)), + vector4 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 7)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 7)), zVec); - 
vector1.Store(dPtr + (uint)(Vector512.Count * 4)); - vector2.Store(dPtr + (uint)(Vector512.Count * 5)); - vector3.Store(dPtr + (uint)(Vector512.Count * 6)); - vector4.Store(dPtr + (uint)(Vector512.Count * 7)); + vector1.Store(dPtr + (uint)(Vector512.Count * 4)); + vector2.Store(dPtr + (uint)(Vector512.Count * 5)); + vector3.Store(dPtr + (uint)(Vector512.Count * 6)); + vector4.Store(dPtr + (uint)(Vector512.Count * 7)); // We adjust the source and destination references, then update // the count of remaining elements to process. - xPtr += (uint)(Vector512.Count * 8); - yPtr += (uint)(Vector512.Count * 8); - dPtr += (uint)(Vector512.Count * 8); + xPtr += (uint)(Vector512.Count * 8); + yPtr += (uint)(Vector512.Count * 8); + dPtr += (uint)(Vector512.Count * 8); - remainder -= (uint)(Vector512.Count * 8); + remainder -= (uint)(Vector512.Count * 8); } } @@ -7887,77 +7944,77 @@ static void Vectorized512(ref float xRef, ref float yRef, float z, ref float dRe // data before the first aligned address. nuint endIndex = remainder; - remainder = (remainder + (uint)(Vector512.Count - 1)) & (nuint)(-Vector512.Count); + remainder = (remainder + (uint)(Vector512.Count - 1)) & (nuint)(-Vector512.Count); - switch (remainder / (uint)(Vector512.Count)) + switch (remainder / (uint)Vector512.Count) { case 8: { - Vector512 vector = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 8)), - Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 8)), - zVec); - vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 8)); + Vector512 vector = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 8)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 8)), + zVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 8)); goto case 7; } case 7: { - Vector512 vector = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 7)), - Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 7)), - zVec); - vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 7)); + Vector512 vector = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 7)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 7)), + zVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 7)); goto case 6; } case 6: { - Vector512 vector = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 6)), - Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 6)), - zVec); - vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 6)); + Vector512 vector = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 6)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 6)), + zVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 6)); goto case 5; } case 5: { - Vector512 vector = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 5)), - Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 5)), - zVec); - vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 5)); + Vector512 vector = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 5)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 5)), + zVec); + vector.StoreUnsafe(ref dRef, remainder - 
(uint)(Vector512.Count * 5)); goto case 4; } case 4: { - Vector512 vector = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 4)), - Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 4)), - zVec); - vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 4)); + Vector512 vector = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 4)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 4)), + zVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 4)); goto case 3; } case 3: { - Vector512 vector = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 3)), - Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 3)), - zVec); - vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 3)); + Vector512 vector = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 3)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 3)), + zVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 3)); goto case 2; } case 2: { - Vector512 vector = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 2)), - Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 2)), - zVec); - vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 2)); + Vector512 vector = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 2)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 2)), + zVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 2)); goto case 1; } case 1: { // Store the last block, which includes any elements that wouldn't fill a full vector - end.StoreUnsafe(ref dRef, endIndex - (uint)Vector512.Count); + end.StoreUnsafe(ref dRef, endIndex - (uint)Vector512.Count); goto case 0; } @@ -7971,7 +8028,7 @@ static void Vectorized512(ref float xRef, ref float yRef, float z, ref float dRe } [MethodImpl(MethodImplOptions.AggressiveInlining)] - static void Vectorized512Small(ref float xRef, ref float yRef, float z, ref float dRef, nuint remainder) + static void Vectorized512Small(ref T xRef, ref T yRef, T z, ref T dRef, nuint remainder) { switch (remainder) { @@ -7985,17 +8042,17 @@ static void Vectorized512Small(ref float xRef, ref float yRef, float z, ref floa { Debug.Assert(Vector256.IsHardwareAccelerated); - Vector256 zVec = Vector256.Create(z); + Vector256 zVec = Vector256.Create(z); - Vector256 beg = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef), + Vector256 beg = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef), Vector256.LoadUnsafe(ref yRef), zVec); - Vector256 end = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count)), - Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count)), + Vector256 end = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)Vector256.Count), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)Vector256.Count), zVec); beg.StoreUnsafe(ref dRef); - end.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count)); + end.StoreUnsafe(ref dRef, remainder - (uint)Vector256.Count); break; } @@ -8004,7 +8061,7 @@ static void Vectorized512Small(ref float xRef, ref float yRef, float z, ref floa { Debug.Assert(Vector256.IsHardwareAccelerated); - Vector256 beg = 
TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef), + Vector256 beg = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef), Vector256.LoadUnsafe(ref yRef), Vector256.Create(z)); beg.StoreUnsafe(ref dRef); @@ -8018,17 +8075,17 @@ static void Vectorized512Small(ref float xRef, ref float yRef, float z, ref floa { Debug.Assert(Vector128.IsHardwareAccelerated); - Vector128 zVec = Vector128.Create(z); + Vector128 zVec = Vector128.Create(z); - Vector128 beg = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef), + Vector128 beg = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef), Vector128.LoadUnsafe(ref yRef), zVec); - Vector128 end = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count)), - Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count)), + Vector128 end = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)Vector128.Count), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)Vector128.Count), zVec); beg.StoreUnsafe(ref dRef); - end.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count)); + end.StoreUnsafe(ref dRef, remainder - (uint)Vector128.Count); break; } @@ -8037,7 +8094,7 @@ static void Vectorized512Small(ref float xRef, ref float yRef, float z, ref floa { Debug.Assert(Vector128.IsHardwareAccelerated); - Vector128 beg = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef), + Vector128 beg = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef), Vector128.LoadUnsafe(ref yRef), Vector128.Create(z)); beg.StoreUnsafe(ref dRef); @@ -8079,13 +8136,14 @@ static void Vectorized512Small(ref float xRef, ref float yRef, float z, ref floa /// Performs an element-wise operation on , , and , /// and writes the results to . /// + /// The element type. /// /// Specifies the operation to perform on the pair-wise element loaded from , with , /// and the element loaded from . 
/// - private static void InvokeSpanScalarSpanIntoSpan( - ReadOnlySpan x, float y, ReadOnlySpan z, Span destination) - where TTernaryOperator : struct, ITernaryOperator + private static void InvokeSpanScalarSpanIntoSpan( + ReadOnlySpan x, T y, ReadOnlySpan z, Span destination) + where TTernaryOperator : struct, ITernaryOperator { if (x.Length != z.Length) { @@ -8105,15 +8163,15 @@ private static void InvokeSpanScalarSpanIntoSpan( // in a way that allows us to have the minimum possible // for small sizes - ref float xRef = ref MemoryMarshal.GetReference(x); - ref float zRef = ref MemoryMarshal.GetReference(z); - ref float dRef = ref MemoryMarshal.GetReference(destination); + ref T xRef = ref MemoryMarshal.GetReference(x); + ref T zRef = ref MemoryMarshal.GetReference(z); + ref T dRef = ref MemoryMarshal.GetReference(destination); - nuint remainder = (uint)(x.Length); + nuint remainder = (uint)x.Length; - if (Vector512.IsHardwareAccelerated) + if (Vector512.IsHardwareAccelerated && Vector512.IsSupported && Unsafe.SizeOf() >= 4) { - if (remainder >= (uint)(Vector512.Count)) + if (remainder >= (uint)Vector512.Count) { Vectorized512(ref xRef, y, ref zRef, ref dRef, remainder); } @@ -8129,9 +8187,9 @@ private static void InvokeSpanScalarSpanIntoSpan( return; } - if (Vector256.IsHardwareAccelerated) + if (Vector256.IsHardwareAccelerated && Vector256.IsSupported && Unsafe.SizeOf() >= 4) { - if (remainder >= (uint)(Vector256.Count)) + if (remainder >= (uint)Vector256.Count) { Vectorized256(ref xRef, y, ref zRef, ref dRef, remainder); } @@ -8147,9 +8205,9 @@ private static void InvokeSpanScalarSpanIntoSpan( return; } - if (Vector128.IsHardwareAccelerated) + if (Vector128.IsHardwareAccelerated && Vector128.IsSupported && Unsafe.SizeOf() >= 4) { - if (remainder >= (uint)(Vector128.Count)) + if (remainder >= (uint)Vector128.Count) { Vectorized128(ref xRef, y, ref zRef, ref dRef, remainder); } @@ -8171,7 +8229,7 @@ private static void InvokeSpanScalarSpanIntoSpan( SoftwareFallback(ref xRef, y, ref zRef, ref dRef, remainder); [MethodImpl(MethodImplOptions.AggressiveInlining)] - static void SoftwareFallback(ref float xRef, float y, ref float zRef, ref float dRef, nuint length) + static void SoftwareFallback(ref T xRef, T y, ref T zRef, ref T dRef, nuint length) { for (nuint i = 0; i < length; i++) { @@ -8181,39 +8239,39 @@ static void SoftwareFallback(ref float xRef, float y, ref float zRef, ref float } } - static void Vectorized128(ref float xRef, float y, ref float zRef, ref float dRef, nuint remainder) + static void Vectorized128(ref T xRef, T y, ref T zRef, ref T dRef, nuint remainder) { - ref float dRefBeg = ref dRef; + ref T dRefBeg = ref dRef; // Preload the beginning and end so that overlapping accesses don't negatively impact the data - Vector128 yVec = Vector128.Create(y); + Vector128 yVec = Vector128.Create(y); - Vector128 beg = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef), + Vector128 beg = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef), yVec, Vector128.LoadUnsafe(ref zRef)); - Vector128 end = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count)), + Vector128 end = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)Vector128.Count), yVec, - Vector128.LoadUnsafe(ref zRef, remainder - (uint)(Vector128.Count))); + Vector128.LoadUnsafe(ref zRef, remainder - (uint)Vector128.Count)); - if (remainder > (uint)(Vector128.Count * 8)) + if (remainder > (uint)(Vector128.Count * 8)) { // Pinning is cheap and will be short lived 
for small inputs and unlikely to be impactful // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted. - fixed (float* px = &xRef) - fixed (float* pz = &zRef) - fixed (float* pd = &dRef) + fixed (T* px = &xRef) + fixed (T* pz = &zRef) + fixed (T* pd = &dRef) { - float* xPtr = px; - float* zPtr = pz; - float* dPtr = pd; + T* xPtr = px; + T* zPtr = pz; + T* dPtr = pd; // We need to ensure the underlying data can be aligned and only align // it if it can. It is possible we have an unaligned ref, in which case we // can never achieve the required SIMD alignment. - bool canAlign = ((nuint)(dPtr) % sizeof(float)) == 0; + bool canAlign = ((nuint)dPtr % (nuint)sizeof(T)) == 0; if (canAlign) { @@ -8223,131 +8281,131 @@ static void Vectorized128(ref float xRef, float y, ref float zRef, ref float dRe // are more expensive than unaligned loads and aligning both is significantly more // complex. - nuint misalignment = ((uint)(sizeof(Vector128)) - ((nuint)(dPtr) % (uint)(sizeof(Vector128)))) / sizeof(float); + nuint misalignment = ((uint)sizeof(Vector128) - ((nuint)dPtr % (uint)sizeof(Vector128))) / (uint)sizeof(T); xPtr += misalignment; zPtr += misalignment; dPtr += misalignment; - Debug.Assert(((nuint)(dPtr) % (uint)(sizeof(Vector128))) == 0); + Debug.Assert(((nuint)dPtr % (uint)sizeof(Vector128)) == 0); remainder -= misalignment; } - Vector128 vector1; - Vector128 vector2; - Vector128 vector3; - Vector128 vector4; + Vector128 vector1; + Vector128 vector2; + Vector128 vector3; + Vector128 vector4; - if ((remainder > (NonTemporalByteThreshold / sizeof(float))) && canAlign) + if ((remainder > (NonTemporalByteThreshold / (nuint)sizeof(T))) && canAlign) { // This loop stores the data non-temporally, which benefits us when there // is a large amount of data involved as it avoids polluting the cache.
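As a concrete check on the alignment arithmetic in the hunk above: only the destination is aligned (aligned stores are what the non-temporal path requires, and aligning all three pointers is rarely possible), and the byte misalignment is converted into an element count. A standalone sketch of the same computation, with a hypothetical helper name and an explicit unmanaged constraint added for illustration:

using System.Runtime.Intrinsics;

internal static unsafe class AlignmentSketch
{
    // Hypothetical helper mirroring the misalignment expression in the diff:
    // how many T-sized elements to process before dPtr reaches the next
    // Vector128<T> (16-byte) boundary.
    public static nuint ElementsUntilAligned<T>(T* dPtr) where T : unmanaged
    {
        // Bytes the pointer sits past the previous 16-byte boundary.
        nuint offset = (nuint)dPtr % (uint)sizeof(Vector128<T>);

        // Bytes to the next boundary, expressed as elements. An already-aligned
        // pointer yields a full vector's worth of elements, as in the diff;
        // that is safe because the preloaded 'beg' vector already covers them.
        return ((uint)sizeof(Vector128<T>) - offset) / (uint)sizeof(T);
    }
}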
- while (remainder >= (uint)(Vector128.Count * 8)) + while (remainder >= (uint)(Vector128.Count * 8)) { // We load, process, and store the first four vectors - vector1 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 0)), + vector1 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 0)), yVec, - Vector128.Load(zPtr + (uint)(Vector128.Count * 0))); - vector2 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 1)), + Vector128.Load(zPtr + (uint)(Vector128.Count * 0))); + vector2 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 1)), yVec, - Vector128.Load(zPtr + (uint)(Vector128.Count * 1))); - vector3 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 2)), + Vector128.Load(zPtr + (uint)(Vector128.Count * 1))); + vector3 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 2)), yVec, - Vector128.Load(zPtr + (uint)(Vector128.Count * 2))); - vector4 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 3)), + Vector128.Load(zPtr + (uint)(Vector128.Count * 2))); + vector4 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 3)), yVec, - Vector128.Load(zPtr + (uint)(Vector128.Count * 3))); + Vector128.Load(zPtr + (uint)(Vector128.Count * 3))); - vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 0)); - vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 1)); - vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 2)); - vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 3)); + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 0)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 1)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 2)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 3)); // We load, process, and store the next four vectors - vector1 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 4)), + vector1 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 4)), yVec, - Vector128.Load(zPtr + (uint)(Vector128.Count * 4))); - vector2 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 5)), + Vector128.Load(zPtr + (uint)(Vector128.Count * 4))); + vector2 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 5)), yVec, - Vector128.Load(zPtr + (uint)(Vector128.Count * 5))); - vector3 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 6)), + Vector128.Load(zPtr + (uint)(Vector128.Count * 5))); + vector3 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 6)), yVec, - Vector128.Load(zPtr + (uint)(Vector128.Count * 6))); - vector4 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 7)), + Vector128.Load(zPtr + (uint)(Vector128.Count * 6))); + vector4 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 7)), yVec, - Vector128.Load(zPtr + (uint)(Vector128.Count * 7))); + Vector128.Load(zPtr + (uint)(Vector128.Count * 7))); - vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 4)); - vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 5)); - vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 6)); - vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 7)); + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 4)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 5)); + 
vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 6)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 7)); // We adjust the source and destination references, then update // the count of remaining elements to process. - xPtr += (uint)(Vector128.Count * 8); - zPtr += (uint)(Vector128.Count * 8); - dPtr += (uint)(Vector128.Count * 8); + xPtr += (uint)(Vector128.Count * 8); + zPtr += (uint)(Vector128.Count * 8); + dPtr += (uint)(Vector128.Count * 8); - remainder -= (uint)(Vector128.Count * 8); + remainder -= (uint)(Vector128.Count * 8); } } else { - while (remainder >= (uint)(Vector128.Count * 8)) + while (remainder >= (uint)(Vector128.Count * 8)) { // We load, process, and store the first four vectors - vector1 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 0)), + vector1 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 0)), yVec, - Vector128.Load(zPtr + (uint)(Vector128.Count * 0))); - vector2 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 1)), + Vector128.Load(zPtr + (uint)(Vector128.Count * 0))); + vector2 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 1)), yVec, - Vector128.Load(zPtr + (uint)(Vector128.Count * 1))); - vector3 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 2)), + Vector128.Load(zPtr + (uint)(Vector128.Count * 1))); + vector3 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 2)), yVec, - Vector128.Load(zPtr + (uint)(Vector128.Count * 2))); - vector4 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 3)), + Vector128.Load(zPtr + (uint)(Vector128.Count * 2))); + vector4 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 3)), yVec, - Vector128.Load(zPtr + (uint)(Vector128.Count * 3))); + Vector128.Load(zPtr + (uint)(Vector128.Count * 3))); - vector1.Store(dPtr + (uint)(Vector128.Count * 0)); - vector2.Store(dPtr + (uint)(Vector128.Count * 1)); - vector3.Store(dPtr + (uint)(Vector128.Count * 2)); - vector4.Store(dPtr + (uint)(Vector128.Count * 3)); + vector1.Store(dPtr + (uint)(Vector128.Count * 0)); + vector2.Store(dPtr + (uint)(Vector128.Count * 1)); + vector3.Store(dPtr + (uint)(Vector128.Count * 2)); + vector4.Store(dPtr + (uint)(Vector128.Count * 3)); // We load, process, and store the next four vectors - vector1 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 4)), + vector1 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 4)), yVec, - Vector128.Load(zPtr + (uint)(Vector128.Count * 4))); - vector2 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 5)), + Vector128.Load(zPtr + (uint)(Vector128.Count * 4))); + vector2 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 5)), yVec, - Vector128.Load(zPtr + (uint)(Vector128.Count * 5))); - vector3 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 6)), + Vector128.Load(zPtr + (uint)(Vector128.Count * 5))); + vector3 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 6)), yVec, - Vector128.Load(zPtr + (uint)(Vector128.Count * 6))); - vector4 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 7)), + Vector128.Load(zPtr + (uint)(Vector128.Count * 6))); + vector4 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 7)), yVec, - Vector128.Load(zPtr + (uint)(Vector128.Count * 7))); + Vector128.Load(zPtr + (uint)(Vector128.Count * 7))); 
- vector1.Store(dPtr + (uint)(Vector128.Count * 4)); - vector2.Store(dPtr + (uint)(Vector128.Count * 5)); - vector3.Store(dPtr + (uint)(Vector128.Count * 6)); - vector4.Store(dPtr + (uint)(Vector128.Count * 7)); + vector1.Store(dPtr + (uint)(Vector128.Count * 4)); + vector2.Store(dPtr + (uint)(Vector128.Count * 5)); + vector3.Store(dPtr + (uint)(Vector128.Count * 6)); + vector4.Store(dPtr + (uint)(Vector128.Count * 7)); // We adjust the source and destination references, then update // the count of remaining elements to process. - xPtr += (uint)(Vector128.Count * 8); - zPtr += (uint)(Vector128.Count * 8); - dPtr += (uint)(Vector128.Count * 8); + xPtr += (uint)(Vector128.Count * 8); + zPtr += (uint)(Vector128.Count * 8); + dPtr += (uint)(Vector128.Count * 8); - remainder -= (uint)(Vector128.Count * 8); + remainder -= (uint)(Vector128.Count * 8); } } @@ -8367,77 +8425,77 @@ static void Vectorized128(ref float xRef, float y, ref float zRef, ref float dRe // data before the first aligned address. nuint endIndex = remainder; - remainder = (remainder + (uint)(Vector128.Count - 1)) & (nuint)(-Vector128.Count); + remainder = (remainder + (uint)(Vector128.Count - 1)) & (nuint)(-Vector128.Count); - switch (remainder / (uint)(Vector128.Count)) + switch (remainder / (uint)Vector128.Count) { case 8: { - Vector128 vector = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 8)), + Vector128 vector = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 8)), yVec, - Vector128.LoadUnsafe(ref zRef, remainder - (uint)(Vector128.Count * 8))); - vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 8)); + Vector128.LoadUnsafe(ref zRef, remainder - (uint)(Vector128.Count * 8))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 8)); goto case 7; } case 7: { - Vector128 vector = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 7)), + Vector128 vector = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 7)), yVec, - Vector128.LoadUnsafe(ref zRef, remainder - (uint)(Vector128.Count * 7))); - vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 7)); + Vector128.LoadUnsafe(ref zRef, remainder - (uint)(Vector128.Count * 7))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 7)); goto case 6; } case 6: { - Vector128 vector = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 6)), + Vector128 vector = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 6)), yVec, - Vector128.LoadUnsafe(ref zRef, remainder - (uint)(Vector128.Count * 6))); - vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 6)); + Vector128.LoadUnsafe(ref zRef, remainder - (uint)(Vector128.Count * 6))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 6)); goto case 5; } case 5: { - Vector128 vector = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 5)), + Vector128 vector = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 5)), yVec, - Vector128.LoadUnsafe(ref zRef, remainder - (uint)(Vector128.Count * 5))); - vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 5)); + Vector128.LoadUnsafe(ref zRef, remainder - (uint)(Vector128.Count * 5))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 5)); goto case 
4; } case 4: { - Vector128 vector = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 4)), + Vector128 vector = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 4)), yVec, - Vector128.LoadUnsafe(ref zRef, remainder - (uint)(Vector128.Count * 4))); - vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 4)); + Vector128.LoadUnsafe(ref zRef, remainder - (uint)(Vector128.Count * 4))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 4)); goto case 3; } case 3: { - Vector128 vector = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 3)), + Vector128 vector = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 3)), yVec, - Vector128.LoadUnsafe(ref zRef, remainder - (uint)(Vector128.Count * 3))); - vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 3)); + Vector128.LoadUnsafe(ref zRef, remainder - (uint)(Vector128.Count * 3))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 3)); goto case 2; } case 2: { - Vector128 vector = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 2)), + Vector128 vector = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 2)), yVec, - Vector128.LoadUnsafe(ref zRef, remainder - (uint)(Vector128.Count * 2))); - vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 2)); + Vector128.LoadUnsafe(ref zRef, remainder - (uint)(Vector128.Count * 2))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 2)); goto case 1; } case 1: { // Store the last block, which includes any elements that wouldn't fill a full vector - end.StoreUnsafe(ref dRef, endIndex - (uint)Vector128.Count); + end.StoreUnsafe(ref dRef, endIndex - (uint)Vector128.Count); goto case 0; } @@ -8451,7 +8509,7 @@ static void Vectorized128(ref float xRef, float y, ref float zRef, ref float dRe } [MethodImpl(MethodImplOptions.AggressiveInlining)] - static void Vectorized128Small(ref float xRef, float y, ref float zRef, ref float dRef, nuint remainder) + static void Vectorized128Small(ref T xRef, T y, ref T zRef, ref T dRef, nuint remainder) { switch (remainder) { @@ -8484,39 +8542,39 @@ static void Vectorized128Small(ref float xRef, float y, ref float zRef, ref floa } } - static void Vectorized256(ref float xRef, float y, ref float zRef, ref float dRef, nuint remainder) + static void Vectorized256(ref T xRef, T y, ref T zRef, ref T dRef, nuint remainder) { - ref float dRefBeg = ref dRef; + ref T dRefBeg = ref dRef; // Preload the beginning and end so that overlapping accesses don't negatively impact the data - Vector256 yVec = Vector256.Create(y); + Vector256 yVec = Vector256.Create(y); - Vector256 beg = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef), + Vector256 beg = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef), yVec, Vector256.LoadUnsafe(ref zRef)); - Vector256 end = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count)), + Vector256 end = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)Vector256.Count), yVec, - Vector256.LoadUnsafe(ref zRef, remainder - (uint)(Vector256.Count))); + Vector256.LoadUnsafe(ref zRef, remainder - (uint)Vector256.Count)); - if (remainder > (uint)(Vector256.Count * 8)) + if (remainder > (uint)(Vector256.Count * 8)) { // Pinning is cheap and will be short lived for small 
inputs and unlikely to be impactful // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted. - fixed (float* px = &xRef) - fixed (float* pz = &zRef) - fixed (float* pd = &dRef) + fixed (T* px = &xRef) + fixed (T* pz = &zRef) + fixed (T* pd = &dRef) { - float* xPtr = px; - float* zPtr = pz; - float* dPtr = pd; + T* xPtr = px; + T* zPtr = pz; + T* dPtr = pd; // We need to ensure the underlying data can be aligned and only align // it if it can. It is possible we have an unaligned ref, in which case we // can never achieve the required SIMD alignment. - bool canAlign = ((nuint)(dPtr) % sizeof(float)) == 0; + bool canAlign = ((nuint)dPtr % (nuint)sizeof(T)) == 0; if (canAlign) { @@ -8526,131 +8584,131 @@ static void Vectorized256(ref float xRef, float y, ref float zRef, ref float dRe // are more expensive than unaligned loads and aligning both is significantly more // complex. - nuint misalignment = ((uint)(sizeof(Vector256)) - ((nuint)(dPtr) % (uint)(sizeof(Vector256)))) / sizeof(float); + nuint misalignment = ((uint)sizeof(Vector256) - ((nuint)dPtr % (uint)sizeof(Vector256))) / (uint)sizeof(T); xPtr += misalignment; zPtr += misalignment; dPtr += misalignment; - Debug.Assert(((nuint)(dPtr) % (uint)(sizeof(Vector256))) == 0); + Debug.Assert(((nuint)dPtr % (uint)sizeof(Vector256)) == 0); remainder -= misalignment; } - Vector256 vector1; - Vector256 vector2; - Vector256 vector3; - Vector256 vector4; + Vector256 vector1; + Vector256 vector2; + Vector256 vector3; + Vector256 vector4; - if ((remainder > (NonTemporalByteThreshold / sizeof(float))) && canAlign) + if ((remainder > (NonTemporalByteThreshold / (nuint)sizeof(T))) && canAlign) { // This loop stores the data non-temporally, which benefits us when there // is a large amount of data involved as it avoids polluting the cache.
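The branch guarding this loop is the interesting policy decision: NonTemporalByteThreshold (an internal constant whose value is not shown in this diff) picks between cached and streaming writes. A minimal sketch of the policy, with an assumed threshold value:

using System.Runtime.Intrinsics;

internal static unsafe class StoreStrategySketch
{
    // Assumed value for illustration only; the real NonTemporalByteThreshold
    // is internal to the implementation and not visible in this diff.
    private const uint NonTemporalByteThreshold = 256 * 1024;

    // Same policy as the loops above: once the remaining data is too large to
    // stay resident in cache anyway, write around the cache so that useful
    // lines are not evicted; otherwise keep the destination cache-warm.
    public static void Write<T>(Vector256<T> value, T* alignedDest, nuint remainingElements)
        where T : unmanaged
    {
        if (remainingElements > NonTemporalByteThreshold / (nuint)sizeof(T))
        {
            value.StoreAlignedNonTemporal(alignedDest); // needs 32-byte alignment
        }
        else
        {
            value.Store(alignedDest);
        }
    }
}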
- while (remainder >= (uint)(Vector256.Count * 8)) + while (remainder >= (uint)(Vector256.Count * 8)) { // We load, process, and store the first four vectors - vector1 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 0)), + vector1 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 0)), yVec, - Vector256.Load(zPtr + (uint)(Vector256.Count * 0))); - vector2 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 1)), + Vector256.Load(zPtr + (uint)(Vector256.Count * 0))); + vector2 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 1)), yVec, - Vector256.Load(zPtr + (uint)(Vector256.Count * 1))); - vector3 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 2)), + Vector256.Load(zPtr + (uint)(Vector256.Count * 1))); + vector3 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 2)), yVec, - Vector256.Load(zPtr + (uint)(Vector256.Count * 2))); - vector4 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 3)), + Vector256.Load(zPtr + (uint)(Vector256.Count * 2))); + vector4 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 3)), yVec, - Vector256.Load(zPtr + (uint)(Vector256.Count * 3))); + Vector256.Load(zPtr + (uint)(Vector256.Count * 3))); - vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 0)); - vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 1)); - vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 2)); - vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 3)); + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 0)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 1)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 2)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 3)); // We load, process, and store the next four vectors - vector1 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 4)), + vector1 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 4)), yVec, - Vector256.Load(zPtr + (uint)(Vector256.Count * 4))); - vector2 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 5)), + Vector256.Load(zPtr + (uint)(Vector256.Count * 4))); + vector2 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 5)), yVec, - Vector256.Load(zPtr + (uint)(Vector256.Count * 5))); - vector3 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 6)), + Vector256.Load(zPtr + (uint)(Vector256.Count * 5))); + vector3 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 6)), yVec, - Vector256.Load(zPtr + (uint)(Vector256.Count * 6))); - vector4 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 7)), + Vector256.Load(zPtr + (uint)(Vector256.Count * 6))); + vector4 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 7)), yVec, - Vector256.Load(zPtr + (uint)(Vector256.Count * 7))); + Vector256.Load(zPtr + (uint)(Vector256.Count * 7))); - vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 4)); - vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 5)); - vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 6)); - vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 7)); + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 4)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 5)); + 
vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 6)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 7)); // We adjust the source and destination references, then update // the count of remaining elements to process. - xPtr += (uint)(Vector256.Count * 8); - zPtr += (uint)(Vector256.Count * 8); - dPtr += (uint)(Vector256.Count * 8); + xPtr += (uint)(Vector256.Count * 8); + zPtr += (uint)(Vector256.Count * 8); + dPtr += (uint)(Vector256.Count * 8); - remainder -= (uint)(Vector256.Count * 8); + remainder -= (uint)(Vector256.Count * 8); } } else { - while (remainder >= (uint)(Vector256.Count * 8)) + while (remainder >= (uint)(Vector256.Count * 8)) { // We load, process, and store the first four vectors - vector1 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 0)), + vector1 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 0)), yVec, - Vector256.Load(zPtr + (uint)(Vector256.Count * 0))); - vector2 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 1)), + Vector256.Load(zPtr + (uint)(Vector256.Count * 0))); + vector2 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 1)), yVec, - Vector256.Load(zPtr + (uint)(Vector256.Count * 1))); - vector3 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 2)), + Vector256.Load(zPtr + (uint)(Vector256.Count * 1))); + vector3 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 2)), yVec, - Vector256.Load(zPtr + (uint)(Vector256.Count * 2))); - vector4 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 3)), + Vector256.Load(zPtr + (uint)(Vector256.Count * 2))); + vector4 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 3)), yVec, - Vector256.Load(zPtr + (uint)(Vector256.Count * 3))); + Vector256.Load(zPtr + (uint)(Vector256.Count * 3))); - vector1.Store(dPtr + (uint)(Vector256.Count * 0)); - vector2.Store(dPtr + (uint)(Vector256.Count * 1)); - vector3.Store(dPtr + (uint)(Vector256.Count * 2)); - vector4.Store(dPtr + (uint)(Vector256.Count * 3)); + vector1.Store(dPtr + (uint)(Vector256.Count * 0)); + vector2.Store(dPtr + (uint)(Vector256.Count * 1)); + vector3.Store(dPtr + (uint)(Vector256.Count * 2)); + vector4.Store(dPtr + (uint)(Vector256.Count * 3)); // We load, process, and store the next four vectors - vector1 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 4)), + vector1 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 4)), yVec, - Vector256.Load(zPtr + (uint)(Vector256.Count * 4))); - vector2 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 5)), + Vector256.Load(zPtr + (uint)(Vector256.Count * 4))); + vector2 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 5)), yVec, - Vector256.Load(zPtr + (uint)(Vector256.Count * 5))); - vector3 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 6)), + Vector256.Load(zPtr + (uint)(Vector256.Count * 5))); + vector3 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 6)), yVec, - Vector256.Load(zPtr + (uint)(Vector256.Count * 6))); - vector4 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 7)), + Vector256.Load(zPtr + (uint)(Vector256.Count * 6))); + vector4 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 7)), yVec, - Vector256.Load(zPtr + (uint)(Vector256.Count * 7))); + Vector256.Load(zPtr + (uint)(Vector256.Count * 7))); 
- vector1.Store(dPtr + (uint)(Vector256.Count * 4)); - vector2.Store(dPtr + (uint)(Vector256.Count * 5)); - vector3.Store(dPtr + (uint)(Vector256.Count * 6)); - vector4.Store(dPtr + (uint)(Vector256.Count * 7)); + vector1.Store(dPtr + (uint)(Vector256.Count * 4)); + vector2.Store(dPtr + (uint)(Vector256.Count * 5)); + vector3.Store(dPtr + (uint)(Vector256.Count * 6)); + vector4.Store(dPtr + (uint)(Vector256.Count * 7)); // We adjust the source and destination references, then update // the count of remaining elements to process. - xPtr += (uint)(Vector256.Count * 8); - zPtr += (uint)(Vector256.Count * 8); - dPtr += (uint)(Vector256.Count * 8); + xPtr += (uint)(Vector256.Count * 8); + zPtr += (uint)(Vector256.Count * 8); + dPtr += (uint)(Vector256.Count * 8); - remainder -= (uint)(Vector256.Count * 8); + remainder -= (uint)(Vector256.Count * 8); } } @@ -8670,77 +8728,77 @@ static void Vectorized256(ref float xRef, float y, ref float zRef, ref float dRe // data before the first aligned address. nuint endIndex = remainder; - remainder = (remainder + (uint)(Vector256.Count - 1)) & (nuint)(-Vector256.Count); + remainder = (remainder + (uint)(Vector256.Count - 1)) & (nuint)(-Vector256.Count); - switch (remainder / (uint)(Vector256.Count)) + switch (remainder / (uint)Vector256.Count) { case 8: { - Vector256 vector = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 8)), + Vector256 vector = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 8)), yVec, - Vector256.LoadUnsafe(ref zRef, remainder - (uint)(Vector256.Count * 8))); - vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 8)); + Vector256.LoadUnsafe(ref zRef, remainder - (uint)(Vector256.Count * 8))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 8)); goto case 7; } case 7: { - Vector256 vector = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 7)), + Vector256 vector = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 7)), yVec, - Vector256.LoadUnsafe(ref zRef, remainder - (uint)(Vector256.Count * 7))); - vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 7)); + Vector256.LoadUnsafe(ref zRef, remainder - (uint)(Vector256.Count * 7))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 7)); goto case 6; } case 6: { - Vector256 vector = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 6)), + Vector256 vector = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 6)), yVec, - Vector256.LoadUnsafe(ref zRef, remainder - (uint)(Vector256.Count * 6))); - vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 6)); + Vector256.LoadUnsafe(ref zRef, remainder - (uint)(Vector256.Count * 6))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 6)); goto case 5; } case 5: { - Vector256 vector = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 5)), + Vector256 vector = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 5)), yVec, - Vector256.LoadUnsafe(ref zRef, remainder - (uint)(Vector256.Count * 5))); - vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 5)); + Vector256.LoadUnsafe(ref zRef, remainder - (uint)(Vector256.Count * 5))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 5)); goto case 
4; } case 4: { - Vector256 vector = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 4)), + Vector256 vector = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 4)), yVec, - Vector256.LoadUnsafe(ref zRef, remainder - (uint)(Vector256.Count * 4))); - vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 4)); + Vector256.LoadUnsafe(ref zRef, remainder - (uint)(Vector256.Count * 4))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 4)); goto case 3; } case 3: { - Vector256 vector = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 3)), + Vector256 vector = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 3)), yVec, - Vector256.LoadUnsafe(ref zRef, remainder - (uint)(Vector256.Count * 3))); - vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 3)); + Vector256.LoadUnsafe(ref zRef, remainder - (uint)(Vector256.Count * 3))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 3)); goto case 2; } case 2: { - Vector256 vector = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 2)), + Vector256 vector = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 2)), yVec, - Vector256.LoadUnsafe(ref zRef, remainder - (uint)(Vector256.Count * 2))); - vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 2)); + Vector256.LoadUnsafe(ref zRef, remainder - (uint)(Vector256.Count * 2))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 2)); goto case 1; } case 1: { // Store the last block, which includes any elements that wouldn't fill a full vector - end.StoreUnsafe(ref dRef, endIndex - (uint)Vector256.Count); + end.StoreUnsafe(ref dRef, endIndex - (uint)Vector256.Count); goto case 0; } @@ -8754,7 +8812,7 @@ static void Vectorized256(ref float xRef, float y, ref float zRef, ref float dRe } [MethodImpl(MethodImplOptions.AggressiveInlining)] - static void Vectorized256Small(ref float xRef, float y, ref float zRef, ref float dRef, nuint remainder) + static void Vectorized256Small(ref T xRef, T y, ref T zRef, ref T dRef, nuint remainder) { switch (remainder) { @@ -8764,17 +8822,17 @@ static void Vectorized256Small(ref float xRef, float y, ref float zRef, ref floa { Debug.Assert(Vector128.IsHardwareAccelerated); - Vector128 yVec = Vector128.Create(y); + Vector128 yVec = Vector128.Create(y); - Vector128 beg = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef), + Vector128 beg = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef), yVec, Vector128.LoadUnsafe(ref zRef)); - Vector128 end = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count)), + Vector128 end = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)Vector128.Count), yVec, - Vector128.LoadUnsafe(ref zRef, remainder - (uint)(Vector128.Count))); + Vector128.LoadUnsafe(ref zRef, remainder - (uint)Vector128.Count)); beg.StoreUnsafe(ref dRef); - end.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count)); + end.StoreUnsafe(ref dRef, remainder - (uint)Vector128.Count); break; } @@ -8783,7 +8841,7 @@ static void Vectorized256Small(ref float xRef, float y, ref float zRef, ref floa { Debug.Assert(Vector128.IsHardwareAccelerated); - Vector128 beg = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef), + Vector128 beg = 
TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef), Vector128.Create(y), Vector128.LoadUnsafe(ref zRef)); beg.StoreUnsafe(ref dRef); @@ -8820,39 +8878,39 @@ static void Vectorized256Small(ref float xRef, float y, ref float zRef, ref floa } } - static void Vectorized512(ref float xRef, float y, ref float zRef, ref float dRef, nuint remainder) + static void Vectorized512(ref T xRef, T y, ref T zRef, ref T dRef, nuint remainder) { - ref float dRefBeg = ref dRef; + ref T dRefBeg = ref dRef; // Preload the beginning and end so that overlapping accesses don't negatively impact the data - Vector512 yVec = Vector512.Create(y); + Vector512 yVec = Vector512.Create(y); - Vector512 beg = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef), + Vector512 beg = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef), yVec, Vector512.LoadUnsafe(ref zRef)); - Vector512 end = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count)), + Vector512 end = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)Vector512.Count), yVec, - Vector512.LoadUnsafe(ref zRef, remainder - (uint)(Vector512.Count))); + Vector512.LoadUnsafe(ref zRef, remainder - (uint)Vector512.Count)); - if (remainder > (uint)(Vector512.Count * 8)) + if (remainder > (uint)(Vector512.Count * 8)) { // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted. - fixed (float* px = &xRef) - fixed (float* pz = &zRef) - fixed (float* pd = &dRef) + fixed (T* px = &xRef) + fixed (T* pz = &zRef) + fixed (T* pd = &dRef) { - float* xPtr = px; - float* zPtr = pz; - float* dPtr = pd; + T* xPtr = px; + T* zPtr = pz; + T* dPtr = pd; // We need to ensure the underlying data can be aligned and only align // it if it can. It is possible we have an unaligned ref, in which case we // can never achieve the required SIMD alignment. - bool canAlign = ((nuint)(dPtr) % sizeof(float)) == 0; + bool canAlign = ((nuint)dPtr % (nuint)sizeof(T)) == 0; if (canAlign) { @@ -8862,131 +8920,131 @@ static void Vectorized512(ref float xRef, float y, ref float zRef, ref float dRe // are more expensive than unaligned loads and aligning both is significantly more // complex. - nuint misalignment = ((uint)(sizeof(Vector512)) - ((nuint)(dPtr) % (uint)(sizeof(Vector512)))) / sizeof(float); + nuint misalignment = ((uint)sizeof(Vector512) - ((nuint)dPtr % (uint)sizeof(Vector512))) / (uint)sizeof(T); xPtr += misalignment; zPtr += misalignment; dPtr += misalignment; - Debug.Assert(((nuint)(dPtr) % (uint)(sizeof(Vector512))) == 0); + Debug.Assert(((nuint)dPtr % (uint)sizeof(Vector512)) == 0); remainder -= misalignment; } - Vector512 vector1; - Vector512 vector2; - Vector512 vector3; - Vector512 vector4; + Vector512 vector1; + Vector512 vector2; + Vector512 vector3; + Vector512 vector4; - if ((remainder > (NonTemporalByteThreshold / sizeof(float))) && canAlign) + if ((remainder > (NonTemporalByteThreshold / (nuint)sizeof(T))) && canAlign) { // This loop stores the data non-temporally, which benefits us when there // is a large amount of data involved as it avoids polluting the cache.
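Each of these main loops is followed by a remainder switch that rounds the outstanding element count up to a multiple of the vector width and then falls through via goto case. A self-contained illustration of that mask trick (the helper name is ours, not the PR's):

// vectorCount must be a power of two, which Vector128/256/512<T>.Count always is.
static nuint RoundUpToVectorMultiple(nuint remainder, int vectorCount)
    => (remainder + (uint)(vectorCount - 1)) & (nuint)(-vectorCount);

// With Vector512<float>.Count == 16 and 53 elements outstanding:
// RoundUpToVectorMultiple(53, 16) == 64, so the switch dispatches on
// 64 / 16 == 4 and falls through cases 4..1, storing whole vectors; 'case 1'
// re-stores the preloaded 'end' vector over the last 16-element window, which
// is how the 53 - 48 = 5 elements that don't fill a full vector get written.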
- while (remainder >= (uint)(Vector512.Count * 8)) + while (remainder >= (uint)(Vector512.Count * 8)) { // We load, process, and store the first four vectors - vector1 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 0)), + vector1 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 0)), yVec, - Vector512.Load(zPtr + (uint)(Vector512.Count * 0))); - vector2 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 1)), + Vector512.Load(zPtr + (uint)(Vector512.Count * 0))); + vector2 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 1)), yVec, - Vector512.Load(zPtr + (uint)(Vector512.Count * 1))); - vector3 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 2)), + Vector512.Load(zPtr + (uint)(Vector512.Count * 1))); + vector3 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 2)), yVec, - Vector512.Load(zPtr + (uint)(Vector512.Count * 2))); - vector4 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 3)), + Vector512.Load(zPtr + (uint)(Vector512.Count * 2))); + vector4 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 3)), yVec, - Vector512.Load(zPtr + (uint)(Vector512.Count * 3))); + Vector512.Load(zPtr + (uint)(Vector512.Count * 3))); - vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 0)); - vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 1)); - vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 2)); - vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 3)); + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 0)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 1)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 2)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 3)); // We load, process, and store the next four vectors - vector1 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 4)), + vector1 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 4)), yVec, - Vector512.Load(zPtr + (uint)(Vector512.Count * 4))); - vector2 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 5)), + Vector512.Load(zPtr + (uint)(Vector512.Count * 4))); + vector2 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 5)), yVec, - Vector512.Load(zPtr + (uint)(Vector512.Count * 5))); - vector3 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 6)), + Vector512.Load(zPtr + (uint)(Vector512.Count * 5))); + vector3 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 6)), yVec, - Vector512.Load(zPtr + (uint)(Vector512.Count * 6))); - vector4 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 7)), + Vector512.Load(zPtr + (uint)(Vector512.Count * 6))); + vector4 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 7)), yVec, - Vector512.Load(zPtr + (uint)(Vector512.Count * 7))); + Vector512.Load(zPtr + (uint)(Vector512.Count * 7))); - vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 4)); - vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 5)); - vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 6)); - vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 7)); + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 4)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 5)); + 
vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 6)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 7)); // We adjust the source and destination references, then update // the count of remaining elements to process. - xPtr += (uint)(Vector512.Count * 8); - zPtr += (uint)(Vector512.Count * 8); - dPtr += (uint)(Vector512.Count * 8); + xPtr += (uint)(Vector512.Count * 8); + zPtr += (uint)(Vector512.Count * 8); + dPtr += (uint)(Vector512.Count * 8); - remainder -= (uint)(Vector512.Count * 8); + remainder -= (uint)(Vector512.Count * 8); } } else { - while (remainder >= (uint)(Vector512.Count * 8)) + while (remainder >= (uint)(Vector512.Count * 8)) { // We load, process, and store the first four vectors - vector1 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 0)), + vector1 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 0)), yVec, - Vector512.Load(zPtr + (uint)(Vector512.Count * 0))); - vector2 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 1)), + Vector512.Load(zPtr + (uint)(Vector512.Count * 0))); + vector2 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 1)), yVec, - Vector512.Load(zPtr + (uint)(Vector512.Count * 1))); - vector3 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 2)), + Vector512.Load(zPtr + (uint)(Vector512.Count * 1))); + vector3 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 2)), yVec, - Vector512.Load(zPtr + (uint)(Vector512.Count * 2))); - vector4 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 3)), + Vector512.Load(zPtr + (uint)(Vector512.Count * 2))); + vector4 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 3)), yVec, - Vector512.Load(zPtr + (uint)(Vector512.Count * 3))); + Vector512.Load(zPtr + (uint)(Vector512.Count * 3))); - vector1.Store(dPtr + (uint)(Vector512.Count * 0)); - vector2.Store(dPtr + (uint)(Vector512.Count * 1)); - vector3.Store(dPtr + (uint)(Vector512.Count * 2)); - vector4.Store(dPtr + (uint)(Vector512.Count * 3)); + vector1.Store(dPtr + (uint)(Vector512.Count * 0)); + vector2.Store(dPtr + (uint)(Vector512.Count * 1)); + vector3.Store(dPtr + (uint)(Vector512.Count * 2)); + vector4.Store(dPtr + (uint)(Vector512.Count * 3)); // We load, process, and store the next four vectors - vector1 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 4)), + vector1 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 4)), yVec, - Vector512.Load(zPtr + (uint)(Vector512.Count * 4))); - vector2 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 5)), + Vector512.Load(zPtr + (uint)(Vector512.Count * 4))); + vector2 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 5)), yVec, - Vector512.Load(zPtr + (uint)(Vector512.Count * 5))); - vector3 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 6)), + Vector512.Load(zPtr + (uint)(Vector512.Count * 5))); + vector3 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 6)), yVec, - Vector512.Load(zPtr + (uint)(Vector512.Count * 6))); - vector4 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 7)), + Vector512.Load(zPtr + (uint)(Vector512.Count * 6))); + vector4 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 7)), yVec, - Vector512.Load(zPtr + (uint)(Vector512.Count * 7))); + Vector512.Load(zPtr + (uint)(Vector512.Count * 7))); 
- vector1.Store(dPtr + (uint)(Vector512.Count * 4)); - vector2.Store(dPtr + (uint)(Vector512.Count * 5)); - vector3.Store(dPtr + (uint)(Vector512.Count * 6)); - vector4.Store(dPtr + (uint)(Vector512.Count * 7)); + vector1.Store(dPtr + (uint)(Vector512.Count * 4)); + vector2.Store(dPtr + (uint)(Vector512.Count * 5)); + vector3.Store(dPtr + (uint)(Vector512.Count * 6)); + vector4.Store(dPtr + (uint)(Vector512.Count * 7)); // We adjust the source and destination references, then update // the count of remaining elements to process. - xPtr += (uint)(Vector512.Count * 8); - zPtr += (uint)(Vector512.Count * 8); - dPtr += (uint)(Vector512.Count * 8); + xPtr += (uint)(Vector512.Count * 8); + zPtr += (uint)(Vector512.Count * 8); + dPtr += (uint)(Vector512.Count * 8); - remainder -= (uint)(Vector512.Count * 8); + remainder -= (uint)(Vector512.Count * 8); } } @@ -9006,77 +9064,77 @@ static void Vectorized512(ref float xRef, float y, ref float zRef, ref float dRe // data before the first aligned address. nuint endIndex = remainder; - remainder = (remainder + (uint)(Vector512.Count - 1)) & (nuint)(-Vector512.Count); + remainder = (remainder + (uint)(Vector512.Count - 1)) & (nuint)(-Vector512.Count); - switch (remainder / (uint)(Vector512.Count)) + switch (remainder / (uint)Vector512.Count) { case 8: { - Vector512 vector = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 8)), + Vector512 vector = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 8)), yVec, - Vector512.LoadUnsafe(ref zRef, remainder - (uint)(Vector512.Count * 8))); - vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 8)); + Vector512.LoadUnsafe(ref zRef, remainder - (uint)(Vector512.Count * 8))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 8)); goto case 7; } case 7: { - Vector512 vector = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 7)), + Vector512 vector = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 7)), yVec, - Vector512.LoadUnsafe(ref zRef, remainder - (uint)(Vector512.Count * 7))); - vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 7)); + Vector512.LoadUnsafe(ref zRef, remainder - (uint)(Vector512.Count * 7))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 7)); goto case 6; } case 6: { - Vector512 vector = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 6)), + Vector512 vector = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 6)), yVec, - Vector512.LoadUnsafe(ref zRef, remainder - (uint)(Vector512.Count * 6))); - vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 6)); + Vector512.LoadUnsafe(ref zRef, remainder - (uint)(Vector512.Count * 6))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 6)); goto case 5; } case 5: { - Vector512 vector = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 5)), + Vector512 vector = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 5)), yVec, - Vector512.LoadUnsafe(ref zRef, remainder - (uint)(Vector512.Count * 5))); - vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 5)); + Vector512.LoadUnsafe(ref zRef, remainder - (uint)(Vector512.Count * 5))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 5)); goto case 
4; } case 4: { - Vector512 vector = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 4)), + Vector512 vector = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 4)), yVec, - Vector512.LoadUnsafe(ref zRef, remainder - (uint)(Vector512.Count * 4))); - vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 4)); + Vector512.LoadUnsafe(ref zRef, remainder - (uint)(Vector512.Count * 4))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 4)); goto case 3; } case 3: { - Vector512 vector = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 3)), + Vector512 vector = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 3)), yVec, - Vector512.LoadUnsafe(ref zRef, remainder - (uint)(Vector512.Count * 3))); - vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 3)); + Vector512.LoadUnsafe(ref zRef, remainder - (uint)(Vector512.Count * 3))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 3)); goto case 2; } case 2: { - Vector512 vector = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 2)), + Vector512 vector = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 2)), yVec, - Vector512.LoadUnsafe(ref zRef, remainder - (uint)(Vector512.Count * 2))); - vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 2)); + Vector512.LoadUnsafe(ref zRef, remainder - (uint)(Vector512.Count * 2))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 2)); goto case 1; } case 1: { // Store the last block, which includes any elements that wouldn't fill a full vector - end.StoreUnsafe(ref dRef, endIndex - (uint)Vector512.Count); + end.StoreUnsafe(ref dRef, endIndex - (uint)Vector512.Count); goto case 0; } @@ -9090,7 +9148,7 @@ static void Vectorized512(ref float xRef, float y, ref float zRef, ref float dRe } [MethodImpl(MethodImplOptions.AggressiveInlining)] - static void Vectorized512Small(ref float xRef, float y, ref float zRef, ref float dRef, nuint remainder) + static void Vectorized512Small(ref T xRef, T y, ref T zRef, ref T dRef, nuint remainder) { switch (remainder) { @@ -9104,17 +9162,17 @@ static void Vectorized512Small(ref float xRef, float y, ref float zRef, ref floa { Debug.Assert(Vector256.IsHardwareAccelerated); - Vector256 yVec = Vector256.Create(y); + Vector256 yVec = Vector256.Create(y); - Vector256 beg = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef), + Vector256 beg = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef), yVec, Vector256.LoadUnsafe(ref zRef)); - Vector256 end = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count)), + Vector256 end = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)Vector256.Count), yVec, - Vector256.LoadUnsafe(ref zRef, remainder - (uint)(Vector256.Count))); + Vector256.LoadUnsafe(ref zRef, remainder - (uint)Vector256.Count)); beg.StoreUnsafe(ref dRef); - end.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count)); + end.StoreUnsafe(ref dRef, remainder - (uint)Vector256.Count); break; } @@ -9123,7 +9181,7 @@ static void Vectorized512Small(ref float xRef, float y, ref float zRef, ref floa { Debug.Assert(Vector256.IsHardwareAccelerated); - Vector256 beg = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef), + Vector256 beg = 
TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef), Vector256.Create(y), Vector256.LoadUnsafe(ref zRef)); beg.StoreUnsafe(ref dRef); @@ -9137,17 +9195,17 @@ static void Vectorized512Small(ref float xRef, float y, ref float zRef, ref floa { Debug.Assert(Vector128.IsHardwareAccelerated); - Vector128 yVec = Vector128.Create(y); + Vector128 yVec = Vector128.Create(y); - Vector128 beg = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef), + Vector128 beg = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef), yVec, Vector128.LoadUnsafe(ref zRef)); - Vector128 end = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count)), + Vector128 end = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)Vector128.Count), yVec, - Vector128.LoadUnsafe(ref zRef, remainder - (uint)(Vector128.Count))); + Vector128.LoadUnsafe(ref zRef, remainder - (uint)Vector128.Count)); beg.StoreUnsafe(ref dRef); - end.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count)); + end.StoreUnsafe(ref dRef, remainder - (uint)Vector128.Count); break; } @@ -9156,7 +9214,7 @@ static void Vectorized512Small(ref float xRef, float y, ref float zRef, ref floa { Debug.Assert(Vector128.IsHardwareAccelerated); - Vector128 beg = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef), + Vector128 beg = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef), Vector128.Create(y), Vector128.LoadUnsafe(ref zRef)); beg.StoreUnsafe(ref dRef); @@ -9196,16 +9254,42 @@ static void Vectorized512Small(ref float xRef, float y, ref float zRef, ref floa /// Performs (x * y) + z. It will be rounded as one ternary operation if such an operation is accelerated on the current hardware. [MethodImpl(MethodImplOptions.AggressiveInlining)] - private static Vector128 FusedMultiplyAdd(Vector128 x, Vector128 y, Vector128 addend) + private static T FusedMultiplyAdd(T x, T y, T addend) where T : INumberBase + { + if (typeof(T) == typeof(Half)) + { + Half result = Half.FusedMultiplyAdd(Unsafe.As(ref x), Unsafe.As(ref y), Unsafe.As(ref addend)); + return Unsafe.As(ref result); + } + + if (typeof(T) == typeof(float)) + { + float result = float.FusedMultiplyAdd(Unsafe.As(ref x), Unsafe.As(ref y), Unsafe.As(ref addend)); + return Unsafe.As(ref result); + } + + if (typeof(T) == typeof(double)) + { + double result = double.FusedMultiplyAdd(Unsafe.As(ref x), Unsafe.As(ref y), Unsafe.As(ref addend)); + return Unsafe.As(ref result); + } + + return (x * y) + addend; + } + + /// Performs (x * y) + z. It will be rounded as one ternary operation if such an operation is accelerated on the current hardware. + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static Vector128 FusedMultiplyAdd(Vector128 x, Vector128 y, Vector128 addend) where T : INumberBase { if (Fma.IsSupported) { - return Fma.MultiplyAdd(x, y, addend); + if (typeof(T) == typeof(float)) return Fma.MultiplyAdd(x.AsSingle(), y.AsSingle(), addend.AsSingle()).As(); + if (typeof(T) == typeof(double)) return Fma.MultiplyAdd(x.AsDouble(), y.AsDouble(), addend.AsDouble()).As(); } if (AdvSimd.IsSupported) { - return AdvSimd.FusedMultiplyAdd(addend, x, y); + if (typeof(T) == typeof(float)) return AdvSimd.FusedMultiplyAdd(addend.AsSingle(), x.AsSingle(), y.AsSingle()).As(); } return (x * y) + addend; @@ -9213,11 +9297,12 @@ private static Vector128 FusedMultiplyAdd(Vector128 x, Vector128Performs (x * y) + z. It will be rounded as one ternary operation if such an operation is accelerated on the current hardware. 
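
For reference, here is the shape of the scalar FusedMultiplyAdd helper added above, written out with its generic type arguments (a sketch; the Half branch from the hunk is elided for brevity). Because T is a value type, the JIT specializes the method per instantiation, so each typeof(T) test folds to a constant and the dead branches disappear; Unsafe.As reinterprets T as the concrete type without boxing.

    using System;
    using System.Numerics;
    using System.Runtime.CompilerServices;

    internal static class FmaSketch
    {
        // Reconstructed shape of the scalar helper in this hunk (Half branch elided).
        [MethodImpl(MethodImplOptions.AggressiveInlining)]
        public static T FusedMultiplyAdd<T>(T x, T y, T addend) where T : INumberBase<T>
        {
            if (typeof(T) == typeof(float))
            {
                // Only reached when T really is float, so the reinterpretation is safe.
                float result = float.FusedMultiplyAdd(
                    Unsafe.As<T, float>(ref x), Unsafe.As<T, float>(ref y), Unsafe.As<T, float>(ref addend));
                return Unsafe.As<float, T>(ref result);
            }

            if (typeof(T) == typeof(double))
            {
                double result = double.FusedMultiplyAdd(
                    Unsafe.As<T, double>(ref x), Unsafe.As<T, double>(ref y), Unsafe.As<T, double>(ref addend));
                return Unsafe.As<double, T>(ref result);
            }

            // Integer arithmetic is exact, so the unfused form is already correct.
            return (x * y) + addend;
        }
    }

The Vector128/Vector256/Vector512 overloads in this hunk follow the same dispatch pattern, forwarding float and double lanes to Fma.MultiplyAdd, AdvSimd.FusedMultiplyAdd, or Avx512F.FusedMultiplyAdd where the hardware supports it, and otherwise falling back to the unfused (x * y) + addend.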
[MethodImpl(MethodImplOptions.AggressiveInlining)] - private static Vector256 FusedMultiplyAdd(Vector256 x, Vector256 y, Vector256 addend) + private static Vector256 FusedMultiplyAdd(Vector256 x, Vector256 y, Vector256 addend) where T : INumberBase { if (Fma.IsSupported) { - return Fma.MultiplyAdd(x, y, addend); + if (typeof(T) == typeof(float)) return Fma.MultiplyAdd(x.AsSingle(), y.AsSingle(), addend.AsSingle()).As(); + if (typeof(T) == typeof(double)) return Fma.MultiplyAdd(x.AsDouble(), y.AsDouble(), addend.AsDouble()).As(); } return (x * y) + addend; @@ -9225,58 +9310,123 @@ private static Vector256 FusedMultiplyAdd(Vector256 x, Vector256Performs (x * y) + z. It will be rounded as one ternary operation if such an operation is accelerated on the current hardware. [MethodImpl(MethodImplOptions.AggressiveInlining)] - private static Vector512 FusedMultiplyAdd(Vector512 x, Vector512 y, Vector512 addend) + private static Vector512 FusedMultiplyAdd(Vector512 x, Vector512 y, Vector512 addend) where T : INumberBase { if (Avx512F.IsSupported) { - return Avx512F.FusedMultiplyAdd(x, y, addend); + if (typeof(T) == typeof(float)) return Avx512F.FusedMultiplyAdd(x.AsSingle(), y.AsSingle(), addend.AsSingle()).As(); + if (typeof(T) == typeof(double)) return Avx512F.FusedMultiplyAdd(x.AsDouble(), y.AsDouble(), addend.AsDouble()).As(); } return (x * y) + addend; } /// Aggregates all of the elements in the into a single value. + /// The element type. /// Specifies the operation to be performed on each pair of values. [MethodImpl(MethodImplOptions.AggressiveInlining)] - private static float HorizontalAggregate(Vector128 x) where TAggregate : struct, IBinaryOperator + private static T HorizontalAggregate(Vector128 x) where TAggregate : struct, IBinaryOperator { // We need to do log2(count) operations to compute the total sum - x = TAggregate.Invoke(x, Vector128.Shuffle(x, Vector128.Create(2, 3, 0, 1))); - x = TAggregate.Invoke(x, Vector128.Shuffle(x, Vector128.Create(1, 0, 3, 2))); + if (Unsafe.SizeOf() == 1) + { + x = TAggregate.Invoke(x, Vector128.Shuffle(x.AsByte(), Vector128.Create((byte)8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7)).As()); + x = TAggregate.Invoke(x, Vector128.Shuffle(x.AsByte(), Vector128.Create((byte)4, 5, 6, 7, 0, 1, 2, 3, 8, 9, 10, 11, 12, 13, 14, 15)).As()); + x = TAggregate.Invoke(x, Vector128.Shuffle(x.AsByte(), Vector128.Create((byte)2, 3, 0, 1, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15)).As()); + x = TAggregate.Invoke(x, Vector128.Shuffle(x.AsByte(), Vector128.Create((byte)1, 0, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15)).As()); + } + else if (Unsafe.SizeOf() == 2) + { + x = TAggregate.Invoke(x, Vector128.Shuffle(x.AsInt16(), Vector128.Create(4, 5, 6, 7, 0, 1, 2, 3)).As()); + x = TAggregate.Invoke(x, Vector128.Shuffle(x.AsInt16(), Vector128.Create(2, 3, 0, 1, 4, 5, 6, 7)).As()); + x = TAggregate.Invoke(x, Vector128.Shuffle(x.AsInt16(), Vector128.Create(1, 0, 2, 3, 4, 5, 6, 7)).As()); + } + else if (Unsafe.SizeOf() == 4) + { + x = TAggregate.Invoke(x, Vector128.Shuffle(x.AsInt32(), Vector128.Create(2, 3, 0, 1)).As()); + x = TAggregate.Invoke(x, Vector128.Shuffle(x.AsInt32(), Vector128.Create(1, 0, 3, 2)).As()); + } + else if (Unsafe.SizeOf() == 8) + { + x = TAggregate.Invoke(x, Vector128.Shuffle(x.AsInt64(), Vector128.Create(1, 0)).As()); + } + else + { + Debug.Fail("Should not be reachable"); + throw new NotSupportedException(); + } return x.ToScalar(); } /// Aggregates all of the elements in the into a single value. + /// The element type. 
/// Specifies the operation to be performed on each pair of values. [MethodImpl(MethodImplOptions.AggressiveInlining)] - private static float HorizontalAggregate(Vector256 x) where TAggregate : struct, IBinaryOperator => - HorizontalAggregate(TAggregate.Invoke(x.GetLower(), x.GetUpper())); + private static T HorizontalAggregate(Vector256 x) where TAggregate : struct, IBinaryOperator => + HorizontalAggregate(TAggregate.Invoke(x.GetLower(), x.GetUpper())); /// Aggregates all of the elements in the into a single value. + /// The element type. /// Specifies the operation to be performed on each pair of values. [MethodImpl(MethodImplOptions.AggressiveInlining)] - private static float HorizontalAggregate(Vector512 x) where TAggregate : struct, IBinaryOperator => - HorizontalAggregate(TAggregate.Invoke(x.GetLower(), x.GetUpper())); + private static T HorizontalAggregate(Vector512 x) where TAggregate : struct, IBinaryOperator => + HorizontalAggregate(TAggregate.Invoke(x.GetLower(), x.GetUpper())); /// Gets whether the specified is negative. - private static bool IsNegative(float f) => float.IsNegative(f); + private static bool IsNegative(T f) where T : INumberBase => T.IsNegative(f); /// Gets whether each specified is negative. [MethodImpl(MethodImplOptions.AggressiveInlining)] - private static Vector128 IsNegative(Vector128 vector) => - Vector128.LessThan(vector.AsInt32(), Vector128.Zero).AsSingle(); + private static Vector128 IsNegative(Vector128 vector) + { + if (typeof(T) == typeof(float)) + { + return Vector128.LessThan(vector.AsInt32(), Vector128.Zero).As(); + } + + if (typeof(T) == typeof(double)) + { + return Vector128.LessThan(vector.AsInt64(), Vector128.Zero).As(); + } + + return Vector128.LessThan(vector, Vector128.Zero); + } /// Gets whether each specified is negative. [MethodImpl(MethodImplOptions.AggressiveInlining)] - private static Vector256 IsNegative(Vector256 vector) => - Vector256.LessThan(vector.AsInt32(), Vector256.Zero).AsSingle(); + private static Vector256 IsNegative(Vector256 vector) + { + if (typeof(T) == typeof(float)) + { + return Vector256.LessThan(vector.AsInt32(), Vector256.Zero).As(); + } + + if (typeof(T) == typeof(double)) + { + return Vector256.LessThan(vector.AsInt64(), Vector256.Zero).As(); + } + + return Vector256.LessThan(vector, Vector256.Zero); + } /// Gets whether each specified is negative. [MethodImpl(MethodImplOptions.AggressiveInlining)] - private static Vector512 IsNegative(Vector512 vector) => - Vector512.LessThan(vector.AsInt32(), Vector512.Zero).AsSingle(); + private static Vector512 IsNegative(Vector512 vector) + { + if (typeof(T) == typeof(float)) + { + return Vector512.LessThan(vector.AsInt32(), Vector512.Zero).As(); + } + + if (typeof(T) == typeof(double)) + { + return Vector512.LessThan(vector.AsInt64(), Vector512.Zero).As(); + } + + return Vector512.LessThan(vector, Vector512.Zero); + } /// Gets whether the specified is positive. private static bool IsPositive(float f) => float.IsPositive(f); @@ -9296,183 +9446,411 @@ private static Vector256 IsPositive(Vector256 vector) => private static Vector512 IsPositive(Vector512 vector) => Vector512.GreaterThan(vector.AsInt32(), Vector512.AllBitsSet).AsSingle(); - /// Gets the base 2 logarithm of . - private static float Log2(float x) => MathF.Log2(x); - /// /// Gets a vector mask that will be all-ones-set for the last elements /// and zero for all other elements. 
/// [MethodImpl(MethodImplOptions.AggressiveInlining)] - private static Vector128 CreateAlignmentMaskSingleVector128(int count) => - Vector128.LoadUnsafe( - ref Unsafe.As(ref MemoryMarshal.GetReference(AlignmentUInt32Mask_16x16)), - (uint)(count * 16)); // first four floats in the row + private static Vector128 CreateAlignmentMaskVector128(int count) + { + if (Unsafe.SizeOf() == 1) + { + return Vector128.LoadUnsafe( + ref Unsafe.As(ref MemoryMarshal.GetReference(AlignmentByteMask_64x65)), + (uint)(count * 64)); + } - /// - /// Gets a vector mask that will be all-ones-set for the last elements - /// and zero for all other elements. - /// - [MethodImpl(MethodImplOptions.AggressiveInlining)] - private static Vector256 CreateAlignmentMaskSingleVector256(int count) => - Vector256.LoadUnsafe( - ref Unsafe.As(ref MemoryMarshal.GetReference(AlignmentUInt32Mask_16x16)), - (uint)(count * 16)); // first eight floats in the row + if (Unsafe.SizeOf() == 2) + { + return Vector128.LoadUnsafe( + ref Unsafe.As(ref MemoryMarshal.GetReference(AlignmentUInt16Mask_32x33)), + (uint)(count * 32)); + } - /// - /// Gets a vector mask that will be all-ones-set for the last elements - /// and zero for all other elements. - /// - [MethodImpl(MethodImplOptions.AggressiveInlining)] - private static Vector512 CreateAlignmentMaskSingleVector512(int count) => - Vector512.LoadUnsafe( - ref Unsafe.As(ref MemoryMarshal.GetReference(AlignmentUInt32Mask_16x16)), - (uint)(count * 16)); // all sixteen floats in the row + if (Unsafe.SizeOf() == 4) + { + return Vector128.LoadUnsafe( + ref Unsafe.As(ref MemoryMarshal.GetReference(AlignmentUInt32Mask_16x17)), + (uint)(count * 16)); + } + + if (Unsafe.SizeOf() == 8) + { + return Vector128.LoadUnsafe( + ref Unsafe.As(ref MemoryMarshal.GetReference(AlignmentUInt64Mask_8x9)), + (uint)(count * 8)); + } + + Debug.Fail("Shouldn't get here"); + throw new NotSupportedException(); + } /// /// Gets a vector mask that will be all-ones-set for the last elements /// and zero for all other elements. /// [MethodImpl(MethodImplOptions.AggressiveInlining)] - private static Vector128 CreateRemainderMaskSingleVector128(int count) => - Vector128.LoadUnsafe( - ref Unsafe.As(ref MemoryMarshal.GetReference(RemainderUInt32Mask_16x16)), - (uint)((count * 16) + 12)); // last four floats in the row + private static Vector256 CreateAlignmentMaskVector256(int count) + { + if (Unsafe.SizeOf() == 1) + { + return Vector256.LoadUnsafe( + ref Unsafe.As(ref MemoryMarshal.GetReference(AlignmentByteMask_64x65)), + (uint)(count * 64)); + } + + if (Unsafe.SizeOf() == 2) + { + return Vector256.LoadUnsafe( + ref Unsafe.As(ref MemoryMarshal.GetReference(AlignmentUInt16Mask_32x33)), + (uint)(count * 32)); + } + + if (Unsafe.SizeOf() == 4) + { + return Vector256.LoadUnsafe( + ref Unsafe.As(ref MemoryMarshal.GetReference(AlignmentUInt32Mask_16x17)), + (uint)(count * 16)); + } + + if (Unsafe.SizeOf() == 8) + { + return Vector256.LoadUnsafe( + ref Unsafe.As(ref MemoryMarshal.GetReference(AlignmentUInt64Mask_8x9)), + (uint)(count * 8)); + } + + Debug.Fail("Shouldn't get here"); + throw new NotSupportedException(); + } /// /// Gets a vector mask that will be all-ones-set for the last elements /// and zero for all other elements. 
/// [MethodImpl(MethodImplOptions.AggressiveInlining)] - private static Vector256 CreateRemainderMaskSingleVector256(int count) => - Vector256.LoadUnsafe( - ref Unsafe.As(ref MemoryMarshal.GetReference(RemainderUInt32Mask_16x16)), - (uint)((count * 16) + 8)); // last eight floats in the row + private static Vector512 CreateAlignmentMaskVector512(int count) + { + if (Unsafe.SizeOf() == 1) + { + return Vector512.LoadUnsafe( + ref Unsafe.As(ref MemoryMarshal.GetReference(AlignmentByteMask_64x65)), + (uint)(count * 64)); + } + + if (Unsafe.SizeOf() == 2) + { + return Vector512.LoadUnsafe( + ref Unsafe.As(ref MemoryMarshal.GetReference(AlignmentUInt16Mask_32x33)), + (uint)(count * 32)); + } + + if (Unsafe.SizeOf() == 4) + { + return Vector512.LoadUnsafe( + ref Unsafe.As(ref MemoryMarshal.GetReference(AlignmentUInt32Mask_16x17)), + (uint)(count * 16)); + } + + if (Unsafe.SizeOf() == 8) + { + return Vector512.LoadUnsafe( + ref Unsafe.As(ref MemoryMarshal.GetReference(AlignmentUInt64Mask_8x9)), + (uint)(count * 8)); + } + + Debug.Fail("Shouldn't get here - CreateAlignmentMaskVector512"); + throw new NotSupportedException(); + } /// /// Gets a vector mask that will be all-ones-set for the last elements /// and zero for all other elements. /// [MethodImpl(MethodImplOptions.AggressiveInlining)] - private static Vector512 CreateRemainderMaskSingleVector512(int count) => - Vector512.LoadUnsafe( - ref Unsafe.As(ref MemoryMarshal.GetReference(RemainderUInt32Mask_16x16)), - (uint)(count * 16)); // all sixteen floats in the row - - /// x + y - private readonly struct AddOperator : IAggregationOperator + private static Vector128 CreateRemainderMaskVector128(int count) { - public static float Invoke(float x, float y) => x + y; - public static Vector128 Invoke(Vector128 x, Vector128 y) => x + y; - public static Vector256 Invoke(Vector256 x, Vector256 y) => x + y; - public static Vector512 Invoke(Vector512 x, Vector512 y) => x + y; + if (Unsafe.SizeOf() == 1) + { + return Vector128.LoadUnsafe( + ref Unsafe.As(ref MemoryMarshal.GetReference(RemainderByteMask_64x65)), + (uint)(count * 64) + 48); // last 16 bytes in the row + } + + if (Unsafe.SizeOf() == 2) + { + return Vector128.LoadUnsafe( + ref Unsafe.As(ref MemoryMarshal.GetReference(RemainderUInt16Mask_32x33)), + (uint)(count * 32) + 24); // last 8 shorts in the row + } - public static float Invoke(Vector128 x) => Vector128.Sum(x); - public static float Invoke(Vector256 x) => Vector256.Sum(x); - public static float Invoke(Vector512 x) => Vector512.Sum(x); + if (Unsafe.SizeOf() == 4) + { + return Vector128.LoadUnsafe( + ref Unsafe.As(ref MemoryMarshal.GetReference(RemainderUInt32Mask_16x17)), + (uint)(count * 16) + 12); // last 4 ints in the row + } - public static float IdentityValue => 0; - } + if (Unsafe.SizeOf() == 8) + { + return Vector128.LoadUnsafe( + ref Unsafe.As(ref MemoryMarshal.GetReference(RemainderUInt64Mask_8x9)), + (uint)(count * 8) + 6); // last 2 longs in the row + } - /// x - y - private readonly struct SubtractOperator : IBinaryOperator - { - public static float Invoke(float x, float y) => x - y; - public static Vector128 Invoke(Vector128 x, Vector128 y) => x - y; - public static Vector256 Invoke(Vector256 x, Vector256 y) => x - y; - public static Vector512 Invoke(Vector512 x, Vector512 y) => x - y; + Debug.Fail("Shouldn't get here - CreateRemainderMaskVector128"); + throw new NotSupportedException(); } - /// (x - y) * (x - y) - private readonly struct SubtractSquaredOperator : IBinaryOperator + /// + /// Gets a vector mask that will be 
all-ones-set for the last elements + /// and zero for all other elements. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static Vector256 CreateRemainderMaskVector256(int count) { - public static float Invoke(float x, float y) + if (Unsafe.SizeOf() == 1) { - float tmp = x - y; - return tmp * tmp; + return Vector256.LoadUnsafe( + ref Unsafe.As(ref MemoryMarshal.GetReference(RemainderByteMask_64x65)), + (uint)(count * 64) + 32); // last 32 bytes in the row } - public static Vector128 Invoke(Vector128 x, Vector128 y) + if (Unsafe.SizeOf() == 2) { - Vector128 tmp = x - y; - return tmp * tmp; + return Vector256.LoadUnsafe( + ref Unsafe.As(ref MemoryMarshal.GetReference(RemainderUInt16Mask_32x33)), + (uint)(count * 32) + 16); // last 16 shorts in the row } - public static Vector256 Invoke(Vector256 x, Vector256 y) + if (Unsafe.SizeOf() == 4) { - Vector256 tmp = x - y; - return tmp * tmp; + return Vector256.LoadUnsafe( + ref Unsafe.As(ref MemoryMarshal.GetReference(RemainderUInt32Mask_16x17)), + (uint)(count * 16) + 8); // last 8 ints in the row } - public static Vector512 Invoke(Vector512 x, Vector512 y) + if (Unsafe.SizeOf() == 8) { - Vector512 tmp = x - y; - return tmp * tmp; + return Vector256.LoadUnsafe( + ref Unsafe.As(ref MemoryMarshal.GetReference(RemainderUInt64Mask_8x9)), + (uint)(count * 8) + 4); // last 4 longs in the row } - } - - /// x * y - private readonly struct MultiplyOperator : IAggregationOperator - { - public static float Invoke(float x, float y) => x * y; - public static Vector128 Invoke(Vector128 x, Vector128 y) => x * y; - public static Vector256 Invoke(Vector256 x, Vector256 y) => x * y; - public static Vector512 Invoke(Vector512 x, Vector512 y) => x * y; - public static float Invoke(Vector128 x) => HorizontalAggregate(x); - public static float Invoke(Vector256 x) => HorizontalAggregate(x); - public static float Invoke(Vector512 x) => HorizontalAggregate(x); - - public static float IdentityValue => 1; + Debug.Fail("Shouldn't get here - CreateRemainderMaskVector256"); + throw new NotSupportedException(); } - /// x / y - private readonly struct DivideOperator : IBinaryOperator + /// + /// Gets a vector mask that will be all-ones-set for the last elements + /// and zero for all other elements. 
+ /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static Vector512 CreateRemainderMaskVector512(int count) { - public static float Invoke(float x, float y) => x / y; - public static Vector128 Invoke(Vector128 x, Vector128 y) => x / y; - public static Vector256 Invoke(Vector256 x, Vector256 y) => x / y; - public static Vector512 Invoke(Vector512 x, Vector512 y) => x / y; + if (Unsafe.SizeOf() == 1) + { + return Vector512.LoadUnsafe( + ref Unsafe.As(ref MemoryMarshal.GetReference(RemainderByteMask_64x65)), + (uint)(count * 64)); + } + + if (Unsafe.SizeOf() == 2) + { + return Vector512.LoadUnsafe( + ref Unsafe.As(ref MemoryMarshal.GetReference(RemainderUInt16Mask_32x33)), + (uint)(count * 32)); + } + + if (Unsafe.SizeOf() == 4) + { + return Vector512.LoadUnsafe( + ref Unsafe.As(ref MemoryMarshal.GetReference(RemainderUInt32Mask_16x17)), + (uint)(count * 16)); + } + + if (Unsafe.SizeOf() == 8) + { + return Vector512.LoadUnsafe( + ref Unsafe.As(ref MemoryMarshal.GetReference(RemainderUInt64Mask_8x9)), + (uint)(count * 8)); + } + + Debug.Fail("Shouldn't get here - CreateRemainderMaskVector512"); + throw new NotSupportedException(); } - /// MathF.Max(x, y) (but NaNs may not be propagated) - private readonly struct MaxOperator : IAggregationOperator + /// x + y + internal readonly struct AddOperator : IAggregationOperator where T : IAdditionOperators, IAdditiveIdentity { - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static float Invoke(float x, float y) => - x == y ? - (IsNegative(x) ? y : x) : - (y > x ? y : x); + public static T Invoke(T x, T y) => x + y; + public static Vector128 Invoke(Vector128 x, Vector128 y) => x + y; + public static Vector256 Invoke(Vector256 x, Vector256 y) => x + y; + public static Vector512 Invoke(Vector512 x, Vector512 y) => x + y; + + public static T Invoke(Vector128 x) => Vector128.Sum(x); + public static T Invoke(Vector256 x) => Vector256.Sum(x); + public static T Invoke(Vector512 x) => Vector512.Sum(x); + + public static T IdentityValue => T.AdditiveIdentity; + } + + /// x - y + internal readonly struct SubtractOperator : IBinaryOperator where T : ISubtractionOperators + { + public static T Invoke(T x, T y) => x - y; + public static Vector128 Invoke(Vector128 x, Vector128 y) => x - y; + public static Vector256 Invoke(Vector256 x, Vector256 y) => x - y; + public static Vector512 Invoke(Vector512 x, Vector512 y) => x - y; + } + + /// (x - y) * (x - y) + internal readonly struct SubtractSquaredOperator : IBinaryOperator where T : ISubtractionOperators, IMultiplyOperators + { + public static T Invoke(T x, T y) + { + T tmp = x - y; + return tmp * tmp; + } + + public static Vector128 Invoke(Vector128 x, Vector128 y) + { + Vector128 tmp = x - y; + return tmp * tmp; + } + + public static Vector256 Invoke(Vector256 x, Vector256 y) + { + Vector256 tmp = x - y; + return tmp * tmp; + } + + public static Vector512 Invoke(Vector512 x, Vector512 y) + { + Vector512 tmp = x - y; + return tmp * tmp; + } + } + + /// x * y + internal readonly struct MultiplyOperator : IAggregationOperator where T : IMultiplyOperators, IMultiplicativeIdentity + { + public static T Invoke(T x, T y) => x * y; + public static Vector128 Invoke(Vector128 x, Vector128 y) => x * y; + public static Vector256 Invoke(Vector256 x, Vector256 y) => x * y; + public static Vector512 Invoke(Vector512 x, Vector512 y) => x * y; + + public static T Invoke(Vector128 x) => HorizontalAggregate>(x); + public static T Invoke(Vector256 x) => HorizontalAggregate>(x); + public static T 
Invoke(Vector512 x) => HorizontalAggregate>(x); + + public static T IdentityValue => T.MultiplicativeIdentity; + } + + /// x / y + internal readonly struct DivideOperator : IBinaryOperator where T : IDivisionOperators + { + public static T Invoke(T x, T y) => x / y; + public static Vector128 Invoke(Vector128 x, Vector128 y) => x / y; + public static Vector256 Invoke(Vector256 x, Vector256 y) => x / y; + public static Vector512 Invoke(Vector512 x, Vector512 y) => x / y; + } + + /// T.Max(x, y) (but NaNs may not be propagated) + internal readonly struct MaxOperator : IAggregationOperator where T : INumber + { + public static T Invoke(T x, T y) + { + if (typeof(T) == typeof(Half) || typeof(T) == typeof(float) || typeof(T) == typeof(double)) + { + return x == y ? + (IsNegative(x) ? y : x) : + (y > x ? y : x); + } + + return T.Max(x, y); + } [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static Vector128 Invoke(Vector128 x, Vector128 y) + public static Vector128 Invoke(Vector128 x, Vector128 y) { if (AdvSimd.IsSupported) { - return AdvSimd.Max(x, y); + if (typeof(T) == typeof(byte)) return AdvSimd.Max(x.AsByte(), y.AsByte()).As(); + if (typeof(T) == typeof(sbyte)) return AdvSimd.Max(x.AsSByte(), y.AsSByte()).As(); + if (typeof(T) == typeof(short)) return AdvSimd.Max(x.AsInt16(), y.AsInt16()).As(); + if (typeof(T) == typeof(ushort)) return AdvSimd.Max(x.AsUInt16(), y.AsUInt16()).As(); + if (typeof(T) == typeof(int)) return AdvSimd.Max(x.AsInt32(), y.AsInt32()).As(); + if (typeof(T) == typeof(uint)) return AdvSimd.Max(x.AsUInt32(), y.AsUInt32()).As(); + if (typeof(T) == typeof(float)) return AdvSimd.Max(x.AsSingle(), y.AsSingle()).As(); } - return - Vector128.ConditionalSelect(Vector128.Equals(x, y), - Vector128.ConditionalSelect(IsNegative(x), y, x), - Vector128.Max(x, y)); + if (typeof(T) == typeof(float)) + { + return + Vector128.ConditionalSelect(Vector128.Equals(x, y), + Vector128.ConditionalSelect(IsNegative(x.AsSingle()).As(), y, x), + Vector128.Max(x, y)); + } + + if (typeof(T) == typeof(double)) + { + return + Vector128.ConditionalSelect(Vector128.Equals(x, y), + Vector128.ConditionalSelect(IsNegative(x.AsDouble()).As(), y, x), + Vector128.Max(x, y)); + } + + return Vector128.Max(x, y); } [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static Vector256 Invoke(Vector256 x, Vector256 y) => - Vector256.ConditionalSelect(Vector256.Equals(x, y), - Vector256.ConditionalSelect(IsNegative(x), y, x), - Vector256.Max(x, y)); + public static Vector256 Invoke(Vector256 x, Vector256 y) + { + if (typeof(T) == typeof(float)) + { + return + Vector256.ConditionalSelect(Vector256.Equals(x, y), + Vector256.ConditionalSelect(IsNegative(x.AsSingle()).As(), y, x), + Vector256.Max(x, y)); + } + + if (typeof(T) == typeof(double)) + { + return + Vector256.ConditionalSelect(Vector256.Equals(x, y), + Vector256.ConditionalSelect(IsNegative(x.AsDouble()).As(), y, x), + Vector256.Max(x, y)); + } + + return Vector256.Max(x, y); + } [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static Vector512 Invoke(Vector512 x, Vector512 y) => - Vector512.ConditionalSelect(Vector512.Equals(x, y), - Vector512.ConditionalSelect(IsNegative(x), y, x), - Vector512.Max(x, y)); - - public static float Invoke(Vector128 x) => HorizontalAggregate(x); - public static float Invoke(Vector256 x) => HorizontalAggregate(x); - public static float Invoke(Vector512 x) => HorizontalAggregate(x); + public static Vector512 Invoke(Vector512 x, Vector512 y) + { + if (typeof(T) == typeof(float)) + { + return + 
Vector512.ConditionalSelect(Vector512.Equals(x, y), + Vector512.ConditionalSelect(IsNegative(x.AsSingle()).As(), y, x), + Vector512.Max(x, y)); + } + + if (typeof(T) == typeof(double)) + { + return + Vector512.ConditionalSelect(Vector512.Equals(x, y), + Vector512.ConditionalSelect(IsNegative(x.AsDouble()).As(), y, x), + Vector512.Max(x, y)); + } + + return Vector512.Max(x, y); + } + + public static T Invoke(Vector128 x) => HorizontalAggregate>(x); + public static T Invoke(Vector256 x) => HorizontalAggregate>(x); + public static T Invoke(Vector512 x) => HorizontalAggregate>(x); } private interface IIndexOfOperator @@ -9487,20 +9865,22 @@ private interface IIndexOfOperator } /// Returns the index of MathF.Max(x, y) - private readonly struct IndexOfMaxOperator : IIndexOfOperator + internal readonly struct IndexOfMaxOperator : IIndexOfOperator { [MethodImpl(MethodImplOptions.AggressiveInlining)] public static int Invoke(Vector128 result, Vector128 maxIndex) { - Vector128 tmpResult = Vector128.Shuffle(result, Vector128.Create(2, 3, 0, 1)); - Vector128 tmpIndex = Vector128.Shuffle(maxIndex, Vector128.Create(2, 3, 0, 1)); + Vector128 tmpResult; + Vector128 tmpIndex; + tmpResult = Vector128.Shuffle(result, Vector128.Create(2, 3, 0, 1)); + tmpIndex = Vector128.Shuffle(maxIndex, Vector128.Create(2, 3, 0, 1)); Invoke(ref result, tmpResult, ref maxIndex, tmpIndex); tmpResult = Vector128.Shuffle(result, Vector128.Create(1, 0, 3, 2)); tmpIndex = Vector128.Shuffle(maxIndex, Vector128.Create(1, 0, 3, 2)); - Invoke(ref result, tmpResult, ref maxIndex, tmpIndex); + return maxIndex.ToScalar(); } @@ -9604,7 +9984,7 @@ public static int Invoke(ref float result, float current, int resultIndex, int c } } - private readonly struct IndexOfMaxMagnitudeOperator : IIndexOfOperator + internal readonly struct IndexOfMaxMagnitudeOperator : IIndexOfOperator { [MethodImpl(MethodImplOptions.AggressiveInlining)] public static int Invoke(Vector128 result, Vector128 maxIndex) @@ -9730,7 +10110,7 @@ public static int Invoke(ref float result, float current, int resultIndex, int c } /// Returns the index of MathF.Min(x, y) - private readonly struct IndexOfMinOperator : IIndexOfOperator + internal readonly struct IndexOfMinOperator : IIndexOfOperator { [MethodImpl(MethodImplOptions.AggressiveInlining)] public static int Invoke(Vector128 result, Vector128 resultIndex) @@ -9848,7 +10228,7 @@ public static int Invoke(ref float result, float current, int resultIndex, int c } } - private readonly struct IndexOfMinMagnitudeOperator : IIndexOfOperator + internal readonly struct IndexOfMinMagnitudeOperator : IIndexOfOperator { [MethodImpl(MethodImplOptions.AggressiveInlining)] public static int Invoke(Vector128 result, Vector128 resultIndex) @@ -9973,385 +10353,663 @@ public static int Invoke(ref float result, float current, int resultIndex, int c } } - /// MathF.Max(x, y) - private readonly struct MaxPropagateNaNOperator : IBinaryOperator + /// Max(x, y) + internal readonly struct MaxPropagateNaNOperator : IBinaryOperator + where T : INumber { [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static float Invoke(float x, float y) => MathF.Max(x, y); + public static T Invoke(T x, T y) => T.Max(x, y); [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static Vector128 Invoke(Vector128 x, Vector128 y) + public static Vector128 Invoke(Vector128 x, Vector128 y) { if (AdvSimd.IsSupported) { - return AdvSimd.Max(x, y); + if (typeof(T) == typeof(byte)) return AdvSimd.Max(x.AsByte(), y.AsByte()).As(); + if (typeof(T) == 
typeof(sbyte)) return AdvSimd.Max(x.AsSByte(), y.AsSByte()).As(); + if (typeof(T) == typeof(ushort)) return AdvSimd.Max(x.AsUInt16(), y.AsUInt16()).As(); + if (typeof(T) == typeof(short)) return AdvSimd.Max(x.AsInt16(), y.AsInt16()).As(); + if (typeof(T) == typeof(uint)) return AdvSimd.Max(x.AsUInt32(), y.AsUInt32()).As(); + if (typeof(T) == typeof(int)) return AdvSimd.Max(x.AsInt32(), y.AsInt32()).As(); + if (typeof(T) == typeof(float)) return AdvSimd.Max(x.AsSingle(), y.AsSingle()).As(); } - return - Vector128.ConditionalSelect(Vector128.Equals(x, x), - Vector128.ConditionalSelect(Vector128.Equals(y, y), - Vector128.ConditionalSelect(Vector128.Equals(x, y), - Vector128.ConditionalSelect(IsNegative(x), y, x), - Vector128.Max(x, y)), - y), - x); + if (typeof(T) == typeof(float) || typeof(T) == typeof(double)) + { + return + Vector128.ConditionalSelect(Vector128.Equals(x, x), + Vector128.ConditionalSelect(Vector128.Equals(y, y), + Vector128.ConditionalSelect(Vector128.Equals(x, y), + Vector128.ConditionalSelect(IsNegative(x), y, x), + Vector128.Max(x, y)), + y), + x); + } + + return Vector128.Max(x, y); } [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static Vector256 Invoke(Vector256 x, Vector256 y) => - Vector256.ConditionalSelect(Vector256.Equals(x, x), - Vector256.ConditionalSelect(Vector256.Equals(y, y), - Vector256.ConditionalSelect(Vector256.Equals(x, y), - Vector256.ConditionalSelect(IsNegative(x), y, x), - Vector256.Max(x, y)), - y), - x); + public static Vector256 Invoke(Vector256 x, Vector256 y) + { + if (typeof(T) == typeof(float) || typeof(T) == typeof(double)) + { + return + Vector256.ConditionalSelect(Vector256.Equals(x, x), + Vector256.ConditionalSelect(Vector256.Equals(y, y), + Vector256.ConditionalSelect(Vector256.Equals(x, y), + Vector256.ConditionalSelect(IsNegative(x), y, x), + Vector256.Max(x, y)), + y), + x); + } + + return Vector256.Max(x, y); + } [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static Vector512 Invoke(Vector512 x, Vector512 y) => - Vector512.ConditionalSelect(Vector512.Equals(x, x), - Vector512.ConditionalSelect(Vector512.Equals(y, y), - Vector512.ConditionalSelect(Vector512.Equals(x, y), - Vector512.ConditionalSelect(IsNegative(x), y, x), - Vector512.Max(x, y)), - y), - x); + public static Vector512 Invoke(Vector512 x, Vector512 y) + { + if (typeof(T) == typeof(float) || typeof(T) == typeof(double)) + { + return + Vector512.ConditionalSelect(Vector512.Equals(x, x), + Vector512.ConditionalSelect(Vector512.Equals(y, y), + Vector512.ConditionalSelect(Vector512.Equals(x, y), + Vector512.ConditionalSelect(IsNegative(x), y, x), + Vector512.Max(x, y)), + y), + x); + } + + return Vector512.Max(x, y); + } } /// Operator to get x or y based on which has the larger MathF.Abs (but NaNs may not be propagated) - private readonly struct MaxMagnitudeOperator : IAggregationOperator + internal readonly struct MaxMagnitudeOperator : IAggregationOperator + where T : INumberBase { [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static float Invoke(float x, float y) - { - float xMag = MathF.Abs(x), yMag = MathF.Abs(y); - return - xMag == yMag ? - (IsNegative(x) ? y : x) : - (xMag > yMag ? 
x : y); - } + public static T Invoke(T x, T y) => T.MaxMagnitude(x, y); [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static Vector128 Invoke(Vector128 x, Vector128 y) + public static Vector128 Invoke(Vector128 x, Vector128 y) { - Vector128 xMag = Vector128.Abs(x), yMag = Vector128.Abs(y); - return + Vector128 xMag = Vector128.Abs(x), yMag = Vector128.Abs(y); + + Vector128 result = Vector128.ConditionalSelect(Vector128.Equals(xMag, yMag), Vector128.ConditionalSelect(IsNegative(x), y, x), Vector128.ConditionalSelect(Vector128.GreaterThan(xMag, yMag), x, y)); + + // Handle minimum signed value that should have the largest magnitude + if (typeof(T) == typeof(sbyte) || typeof(T) == typeof(short) || typeof(T) == typeof(int) || typeof(T) == typeof(long) || typeof(T) == typeof(nint)) + { + Vector128 negativeMagnitudeX = Vector128.LessThan(xMag, Vector128.Zero); + Vector128 negativeMagnitudeY = Vector128.LessThan(yMag, Vector128.Zero); + result = Vector128.ConditionalSelect(negativeMagnitudeX, + x, + Vector128.ConditionalSelect(negativeMagnitudeY, + y, + result)); + } + + return result; } [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static Vector256 Invoke(Vector256 x, Vector256 y) + public static Vector256 Invoke(Vector256 x, Vector256 y) { - Vector256 xMag = Vector256.Abs(x), yMag = Vector256.Abs(y); - return + Vector256 xMag = Vector256.Abs(x), yMag = Vector256.Abs(y); + + Vector256 result = Vector256.ConditionalSelect(Vector256.Equals(xMag, yMag), Vector256.ConditionalSelect(IsNegative(x), y, x), Vector256.ConditionalSelect(Vector256.GreaterThan(xMag, yMag), x, y)); + + // Handle minimum signed value that should have the largest magnitude + if (typeof(T) == typeof(sbyte) || typeof(T) == typeof(short) || typeof(T) == typeof(int) || typeof(T) == typeof(long) || typeof(T) == typeof(nint)) + { + Vector256 negativeMagnitudeX = Vector256.LessThan(xMag, Vector256.Zero); + Vector256 negativeMagnitudeY = Vector256.LessThan(yMag, Vector256.Zero); + result = Vector256.ConditionalSelect(negativeMagnitudeX, + x, + Vector256.ConditionalSelect(negativeMagnitudeY, + y, + result)); + } + + return result; } [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static Vector512 Invoke(Vector512 x, Vector512 y) + public static Vector512 Invoke(Vector512 x, Vector512 y) { - Vector512 xMag = Vector512.Abs(x), yMag = Vector512.Abs(y); - return + Vector512 xMag = Vector512.Abs(x), yMag = Vector512.Abs(y); + + Vector512 result = Vector512.ConditionalSelect(Vector512.Equals(xMag, yMag), Vector512.ConditionalSelect(IsNegative(x), y, x), Vector512.ConditionalSelect(Vector512.GreaterThan(xMag, yMag), x, y)); + + // Handle minimum signed value that should have the largest magnitude + if (typeof(T) == typeof(sbyte) || typeof(T) == typeof(short) || typeof(T) == typeof(int) || typeof(T) == typeof(long) || typeof(T) == typeof(nint)) + { + Vector512 negativeMagnitudeX = Vector512.LessThan(xMag, Vector512.Zero); + Vector512 negativeMagnitudeY = Vector512.LessThan(yMag, Vector512.Zero); + result = Vector512.ConditionalSelect(negativeMagnitudeX, + x, + Vector512.ConditionalSelect(negativeMagnitudeY, + y, + result)); + } + + return result; } - public static float Invoke(Vector128 x) => HorizontalAggregate(x); - public static float Invoke(Vector256 x) => HorizontalAggregate(x); - public static float Invoke(Vector512 x) => HorizontalAggregate(x); + public static T Invoke(Vector128 x) => HorizontalAggregate>(x); + public static T Invoke(Vector256 x) => HorizontalAggregate>(x); + public static T 
Invoke(Vector512 x) => HorizontalAggregate>(x); } /// Operator to get x or y based on which has the larger MathF.Abs - private readonly struct MaxMagnitudePropagateNaNOperator : IBinaryOperator + internal readonly struct MaxMagnitudePropagateNaNOperator : IBinaryOperator + where T : INumberBase { [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static float Invoke(float x, float y) => MathF.MaxMagnitude(x, y); + public static T Invoke(T x, T y) => T.MaxMagnitude(x, y); [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static Vector128 Invoke(Vector128 x, Vector128 y) + public static Vector128 Invoke(Vector128 x, Vector128 y) { - Vector128 xMag = Vector128.Abs(x), yMag = Vector128.Abs(y); - return - Vector128.ConditionalSelect(Vector128.Equals(x, x), - Vector128.ConditionalSelect(Vector128.Equals(y, y), - Vector128.ConditionalSelect(Vector128.Equals(yMag, xMag), - Vector128.ConditionalSelect(IsNegative(x), y, x), - Vector128.ConditionalSelect(Vector128.GreaterThan(yMag, xMag), y, x)), - y), - x); + // Handle NaNs + if (typeof(T) == typeof(float) || typeof(T) == typeof(double)) + { + Vector128 xMag = Vector128.Abs(x), yMag = Vector128.Abs(y); + return + Vector128.ConditionalSelect(Vector128.Equals(x, x), + Vector128.ConditionalSelect(Vector128.Equals(y, y), + Vector128.ConditionalSelect(Vector128.Equals(yMag, xMag), + Vector128.ConditionalSelect(IsNegative(x), y, x), + Vector128.ConditionalSelect(Vector128.GreaterThan(yMag, xMag), y, x)), + y), + x); + } + + return MaxMagnitudeOperator.Invoke(x, y); } [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static Vector256 Invoke(Vector256 x, Vector256 y) + public static Vector256 Invoke(Vector256 x, Vector256 y) { - Vector256 xMag = Vector256.Abs(x), yMag = Vector256.Abs(y); - return - Vector256.ConditionalSelect(Vector256.Equals(x, x), - Vector256.ConditionalSelect(Vector256.Equals(y, y), - Vector256.ConditionalSelect(Vector256.Equals(xMag, yMag), - Vector256.ConditionalSelect(IsNegative(x), y, x), - Vector256.ConditionalSelect(Vector256.GreaterThan(xMag, yMag), x, y)), - y), - x); + // Handle NaNs + if (typeof(T) == typeof(float) || typeof(T) == typeof(double)) + { + Vector256 xMag = Vector256.Abs(x), yMag = Vector256.Abs(y); + return + Vector256.ConditionalSelect(Vector256.Equals(x, x), + Vector256.ConditionalSelect(Vector256.Equals(y, y), + Vector256.ConditionalSelect(Vector256.Equals(xMag, yMag), + Vector256.ConditionalSelect(IsNegative(x), y, x), + Vector256.ConditionalSelect(Vector256.GreaterThan(xMag, yMag), x, y)), + y), + x); + } + + return MaxMagnitudeOperator.Invoke(x, y); } [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static Vector512 Invoke(Vector512 x, Vector512 y) + public static Vector512 Invoke(Vector512 x, Vector512 y) { - Vector512 xMag = Vector512.Abs(x), yMag = Vector512.Abs(y); - return - Vector512.ConditionalSelect(Vector512.Equals(x, x), + // Handle NaNs + if (typeof(T) == typeof(float) || typeof(T) == typeof(double)) + { + Vector512 xMag = Vector512.Abs(x), yMag = Vector512.Abs(y); + return + Vector512.ConditionalSelect(Vector512.Equals(x, x), Vector512.ConditionalSelect(Vector512.Equals(y, y), Vector512.ConditionalSelect(Vector512.Equals(xMag, yMag), Vector512.ConditionalSelect(IsNegative(x), y, x), Vector512.ConditionalSelect(Vector512.GreaterThan(xMag, yMag), x, y)), y), x); + } + + return MaxMagnitudeOperator.Invoke(x, y); } } - /// MathF.Min(x, y) (but NaNs may not be propagated) - private readonly struct MinOperator : IAggregationOperator + /// T.Min(x, y) (but NaNs may 
not be propagated) + internal readonly struct MinOperator : IAggregationOperator + where T : INumber { [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static float Invoke(float x, float y) => - x == y ? - (IsNegative(y) ? y : x) : - (y < x ? y : x); + public static T Invoke(T x, T y) + { + if (typeof(T) == typeof(Half) || typeof(T) == typeof(float) || typeof(T) == typeof(double)) + { + return x == y ? + (IsNegative(y) ? y : x) : + (y < x ? y : x); + } + + return T.Min(x, y); + } [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static Vector128 Invoke(Vector128 x, Vector128 y) + public static Vector128 Invoke(Vector128 x, Vector128 y) { if (AdvSimd.IsSupported) { - return AdvSimd.Min(x, y); + if (typeof(T) == typeof(byte)) return AdvSimd.Min(x.AsByte(), y.AsByte()).As(); + if (typeof(T) == typeof(sbyte)) return AdvSimd.Min(x.AsSByte(), y.AsSByte()).As(); + if (typeof(T) == typeof(short)) return AdvSimd.Min(x.AsInt16(), y.AsInt16()).As(); + if (typeof(T) == typeof(ushort)) return AdvSimd.Min(x.AsUInt16(), y.AsUInt16()).As(); + if (typeof(T) == typeof(int)) return AdvSimd.Min(x.AsInt32(), y.AsInt32()).As(); + if (typeof(T) == typeof(uint)) return AdvSimd.Min(x.AsUInt32(), y.AsUInt32()).As(); + if (typeof(T) == typeof(float)) return AdvSimd.Min(x.AsSingle(), y.AsSingle()).As(); } - return - Vector128.ConditionalSelect(Vector128.Equals(x, y), - Vector128.ConditionalSelect(IsNegative(y), y, x), - Vector128.Min(x, y)); + if (typeof(T) == typeof(float) || typeof(T) == typeof(double)) + { + return + Vector128.ConditionalSelect(Vector128.Equals(x, y), + Vector128.ConditionalSelect(IsNegative(y), y, x), + Vector128.Min(x, y)); + } + + return Vector128.Min(x, y); } [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static Vector256 Invoke(Vector256 x, Vector256 y) => - Vector256.ConditionalSelect(Vector256.Equals(x, y), - Vector256.ConditionalSelect(IsNegative(y), y, x), - Vector256.Min(x, y)); + public static Vector256 Invoke(Vector256 x, Vector256 y) + { + if (typeof(T) == typeof(float) || typeof(T) == typeof(double)) + { + return Vector256.ConditionalSelect(Vector256.Equals(x, y), + Vector256.ConditionalSelect(IsNegative(y), y, x), + Vector256.Min(x, y)); + } + + return Vector256.Min(x, y); + } [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static Vector512 Invoke(Vector512 x, Vector512 y) => - Vector512.ConditionalSelect(Vector512.Equals(x, y), - Vector512.ConditionalSelect(IsNegative(y), y, x), - Vector512.Min(x, y)); - - public static float Invoke(Vector128 x) => HorizontalAggregate(x); - public static float Invoke(Vector256 x) => HorizontalAggregate(x); - public static float Invoke(Vector512 x) => HorizontalAggregate(x); + public static Vector512 Invoke(Vector512 x, Vector512 y) + { + if (typeof(T) == typeof(float) || typeof(T) == typeof(double)) + { + return Vector512.ConditionalSelect(Vector512.Equals(x, y), + Vector512.ConditionalSelect(IsNegative(y), y, x), + Vector512.Min(x, y)); + } + + return Vector512.Min(x, y); + } + + public static T Invoke(Vector128 x) => HorizontalAggregate>(x); + public static T Invoke(Vector256 x) => HorizontalAggregate>(x); + public static T Invoke(Vector512 x) => HorizontalAggregate>(x); } - /// MathF.Min(x, y) - private readonly struct MinPropagateNaNOperator : IBinaryOperator + /// T.Min(x, y) + internal readonly struct MinPropagateNaNOperator : IBinaryOperator + where T : INumber { [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static float Invoke(float x, float y) => MathF.Min(x, y); + public static T 
Invoke(T x, T y) => T.Min(x, y); [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static Vector128 Invoke(Vector128 x, Vector128 y) + public static Vector128 Invoke(Vector128 x, Vector128 y) { if (AdvSimd.IsSupported) { - return AdvSimd.Min(x, y); + if (typeof(T) == typeof(byte)) return AdvSimd.Min(x.AsByte(), y.AsByte()).As(); + if (typeof(T) == typeof(sbyte)) return AdvSimd.Min(x.AsSByte(), y.AsSByte()).As(); + if (typeof(T) == typeof(short)) return AdvSimd.Min(x.AsInt16(), y.AsInt16()).As(); + if (typeof(T) == typeof(ushort)) return AdvSimd.Min(x.AsUInt16(), y.AsUInt16()).As(); + if (typeof(T) == typeof(int)) return AdvSimd.Min(x.AsInt32(), y.AsInt32()).As(); + if (typeof(T) == typeof(uint)) return AdvSimd.Min(x.AsUInt32(), y.AsUInt32()).As(); + if (typeof(T) == typeof(float)) return AdvSimd.Min(x.AsSingle(), y.AsSingle()).As(); } - return - Vector128.ConditionalSelect(Vector128.Equals(x, x), - Vector128.ConditionalSelect(Vector128.Equals(y, y), - Vector128.ConditionalSelect(Vector128.Equals(x, y), - Vector128.ConditionalSelect(IsNegative(x), x, y), - Vector128.Min(x, y)), - y), - x); + if (typeof(T) == typeof(float) || typeof(T) == typeof(double)) + { + return + Vector128.ConditionalSelect(Vector128.Equals(x, x), + Vector128.ConditionalSelect(Vector128.Equals(y, y), + Vector128.ConditionalSelect(Vector128.Equals(x, y), + Vector128.ConditionalSelect(IsNegative(x), x, y), + Vector128.Min(x, y)), + y), + x); + } + + return Vector128.Min(x, y); } [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static Vector256 Invoke(Vector256 x, Vector256 y) => - Vector256.ConditionalSelect(Vector256.Equals(x, x), - Vector256.ConditionalSelect(Vector256.Equals(y, y), - Vector256.ConditionalSelect(Vector256.Equals(x, y), - Vector256.ConditionalSelect(IsNegative(x), x, y), - Vector256.Min(x, y)), - y), - x); + public static Vector256 Invoke(Vector256 x, Vector256 y) + { + if (typeof(T) == typeof(float) || typeof(T) == typeof(double)) + { + return + Vector256.ConditionalSelect(Vector256.Equals(x, x), + Vector256.ConditionalSelect(Vector256.Equals(y, y), + Vector256.ConditionalSelect(Vector256.Equals(x, y), + Vector256.ConditionalSelect(IsNegative(x), x, y), + Vector256.Min(x, y)), + y), + x); + } + + return Vector256.Min(x, y); + } [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static Vector512 Invoke(Vector512 x, Vector512 y) => - Vector512.ConditionalSelect(Vector512.Equals(x, x), - Vector512.ConditionalSelect(Vector512.Equals(y, y), - Vector512.ConditionalSelect(Vector512.Equals(x, y), - Vector512.ConditionalSelect(IsNegative(x), x, y), - Vector512.Min(x, y)), - y), - x); + public static Vector512 Invoke(Vector512 x, Vector512 y) + { + if (typeof(T) == typeof(float) || typeof(T) == typeof(double)) + { + return + Vector512.ConditionalSelect(Vector512.Equals(x, x), + Vector512.ConditionalSelect(Vector512.Equals(y, y), + Vector512.ConditionalSelect(Vector512.Equals(x, y), + Vector512.ConditionalSelect(IsNegative(x), x, y), + Vector512.Min(x, y)), + y), + x); + } + + return Vector512.Min(x, y); + } } /// Operator to get x or y based on which has the smaller MathF.Abs (but NaNs may not be propagated) - private readonly struct MinMagnitudeOperator : IAggregationOperator + internal readonly struct MinMagnitudeOperator : IAggregationOperator + where T : INumberBase { [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static float Invoke(float x, float y) - { - float xMag = MathF.Abs(x), yMag = MathF.Abs(y); - return xMag == yMag ? - (IsNegative(y) ? 
y : x) : - (yMag < xMag ? y : x); - } + public static T Invoke(T x, T y) => T.MinMagnitude(x, y); [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static Vector128 Invoke(Vector128 x, Vector128 y) + public static Vector128 Invoke(Vector128 x, Vector128 y) { - Vector128 xMag = Vector128.Abs(x), yMag = Vector128.Abs(y); - return + Vector128 xMag = Vector128.Abs(x), yMag = Vector128.Abs(y); + + Vector128 result = Vector128.ConditionalSelect(Vector128.Equals(yMag, xMag), Vector128.ConditionalSelect(IsNegative(y), y, x), Vector128.ConditionalSelect(Vector128.LessThan(yMag, xMag), y, x)); + + if (typeof(T) == typeof(sbyte) || typeof(T) == typeof(short) || typeof(T) == typeof(int) || typeof(T) == typeof(long) || typeof(T) == typeof(nint)) + { + Vector128 negativeMagnitudeX = Vector128.LessThan(xMag, Vector128.Zero); + Vector128 negativeMagnitudeY = Vector128.LessThan(yMag, Vector128.Zero); + result = Vector128.ConditionalSelect(negativeMagnitudeX, + y, + Vector128.ConditionalSelect(negativeMagnitudeY, + x, + result)); + } + + return result; } [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static Vector256 Invoke(Vector256 x, Vector256 y) + public static Vector256 Invoke(Vector256 x, Vector256 y) { - Vector256 xMag = Vector256.Abs(x), yMag = Vector256.Abs(y); - return + Vector256 xMag = Vector256.Abs(x), yMag = Vector256.Abs(y); + + Vector256 result = Vector256.ConditionalSelect(Vector256.Equals(yMag, xMag), Vector256.ConditionalSelect(IsNegative(y), y, x), Vector256.ConditionalSelect(Vector256.LessThan(yMag, xMag), y, x)); + + if (typeof(T) == typeof(sbyte) || typeof(T) == typeof(short) || typeof(T) == typeof(int) || typeof(T) == typeof(long) || typeof(T) == typeof(nint)) + { + Vector256 negativeMagnitudeX = Vector256.LessThan(xMag, Vector256.Zero); + Vector256 negativeMagnitudeY = Vector256.LessThan(yMag, Vector256.Zero); + result = Vector256.ConditionalSelect(negativeMagnitudeX, + y, + Vector256.ConditionalSelect(negativeMagnitudeY, + x, + result)); + } + + return result; } [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static Vector512 Invoke(Vector512 x, Vector512 y) + public static Vector512 Invoke(Vector512 x, Vector512 y) { - Vector512 xMag = Vector512.Abs(x), yMag = Vector512.Abs(y); - return + Vector512 xMag = Vector512.Abs(x), yMag = Vector512.Abs(y); + + Vector512 result = Vector512.ConditionalSelect(Vector512.Equals(yMag, xMag), Vector512.ConditionalSelect(IsNegative(y), y, x), Vector512.ConditionalSelect(Vector512.LessThan(yMag, xMag), y, x)); + + if (typeof(T) == typeof(sbyte) || typeof(T) == typeof(short) || typeof(T) == typeof(int) || typeof(T) == typeof(long) || typeof(T) == typeof(nint)) + { + Vector512 negativeMagnitudeX = Vector512.LessThan(xMag, Vector512.Zero); + Vector512 negativeMagnitudeY = Vector512.LessThan(yMag, Vector512.Zero); + result = Vector512.ConditionalSelect(negativeMagnitudeX, + y, + Vector512.ConditionalSelect(negativeMagnitudeY, + x, + result)); + } + + return result; } - public static float Invoke(Vector128 x) => HorizontalAggregate(x); - public static float Invoke(Vector256 x) => HorizontalAggregate(x); - public static float Invoke(Vector512 x) => HorizontalAggregate(x); + public static T Invoke(Vector128 x) => HorizontalAggregate>(x); + public static T Invoke(Vector256 x) => HorizontalAggregate>(x); + public static T Invoke(Vector512 x) => HorizontalAggregate>(x); } /// Operator to get x or y based on which has the smaller MathF.Abs - private readonly struct MinMagnitudePropagateNaNOperator : IBinaryOperator + 
internal readonly struct MinMagnitudePropagateNaNOperator : IBinaryOperator + where T : INumberBase { [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static float Invoke(float x, float y) => MathF.MinMagnitude(x, y); + public static T Invoke(T x, T y) => T.MinMagnitude(x, y); [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static Vector128 Invoke(Vector128 x, Vector128 y) + public static Vector128 Invoke(Vector128 x, Vector128 y) { - Vector128 xMag = Vector128.Abs(x), yMag = Vector128.Abs(y); - return - Vector128.ConditionalSelect(Vector128.Equals(x, x), - Vector128.ConditionalSelect(Vector128.Equals(y, y), - Vector128.ConditionalSelect(Vector128.Equals(yMag, xMag), - Vector128.ConditionalSelect(IsNegative(x), x, y), - Vector128.ConditionalSelect(Vector128.LessThan(xMag, yMag), x, y)), - y), - x); + // Handle NaNs + if (typeof(T) == typeof(float) || typeof(T) == typeof(double)) + { + Vector128 xMag = Vector128.Abs(x), yMag = Vector128.Abs(y); + return + Vector128.ConditionalSelect(Vector128.Equals(x, x), + Vector128.ConditionalSelect(Vector128.Equals(y, y), + Vector128.ConditionalSelect(Vector128.Equals(yMag, xMag), + Vector128.ConditionalSelect(IsNegative(x), x, y), + Vector128.ConditionalSelect(Vector128.LessThan(xMag, yMag), x, y)), + y), + x); + } + + return MinMagnitudeOperator.Invoke(x, y); } [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static Vector256 Invoke(Vector256 x, Vector256 y) + public static Vector256 Invoke(Vector256 x, Vector256 y) { - Vector256 xMag = Vector256.Abs(x), yMag = Vector256.Abs(y); - return - Vector256.ConditionalSelect(Vector256.Equals(x, x), - Vector256.ConditionalSelect(Vector256.Equals(y, y), - Vector256.ConditionalSelect(Vector256.Equals(yMag, xMag), - Vector256.ConditionalSelect(IsNegative(x), x, y), - Vector256.ConditionalSelect(Vector256.LessThan(xMag, yMag), x, y)), - y), - x); + // Handle NaNs + if (typeof(T) == typeof(float) || typeof(T) == typeof(double)) + { + Vector256 xMag = Vector256.Abs(x), yMag = Vector256.Abs(y); + return + Vector256.ConditionalSelect(Vector256.Equals(x, x), + Vector256.ConditionalSelect(Vector256.Equals(y, y), + Vector256.ConditionalSelect(Vector256.Equals(yMag, xMag), + Vector256.ConditionalSelect(IsNegative(x), x, y), + Vector256.ConditionalSelect(Vector256.LessThan(xMag, yMag), x, y)), + y), + x); + } + + return MinMagnitudeOperator.Invoke(x, y); } [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static Vector512 Invoke(Vector512 x, Vector512 y) + public static Vector512 Invoke(Vector512 x, Vector512 y) { - Vector512 xMag = Vector512.Abs(x), yMag = Vector512.Abs(y); - return - Vector512.ConditionalSelect(Vector512.Equals(x, x), - Vector512.ConditionalSelect(Vector512.Equals(y, y), - Vector512.ConditionalSelect(Vector512.Equals(yMag, xMag), - Vector512.ConditionalSelect(IsNegative(x), x, y), - Vector512.ConditionalSelect(Vector512.LessThan(xMag, yMag), x, y)), - y), - x); + // Handle NaNs + if (typeof(T) == typeof(float) || typeof(T) == typeof(double)) + { + Vector512 xMag = Vector512.Abs(x), yMag = Vector512.Abs(y); + return + Vector512.ConditionalSelect(Vector512.Equals(x, x), + Vector512.ConditionalSelect(Vector512.Equals(y, y), + Vector512.ConditionalSelect(Vector512.Equals(yMag, xMag), + Vector512.ConditionalSelect(IsNegative(x), x, y), + Vector512.ConditionalSelect(Vector512.LessThan(xMag, yMag), x, y)), + y), + x); + } + + return MinMagnitudeOperator.Invoke(x, y); } } /// -x - private readonly struct NegateOperator : IUnaryOperator + internal readonly struct 
NegateOperator : IUnaryOperator where T : IUnaryNegationOperators { - public static float Invoke(float x) => -x; - public static Vector128 Invoke(Vector128 x) => -x; - public static Vector256 Invoke(Vector256 x) => -x; - public static Vector512 Invoke(Vector512 x) => -x; + public static bool Vectorizable => true; + public static T Invoke(T x) => -x; + public static Vector128 Invoke(Vector128 x) => -x; + public static Vector256 Invoke(Vector256 x) => -x; + public static Vector512 Invoke(Vector512 x) => -x; } /// (x + y) * z - private readonly struct AddMultiplyOperator : ITernaryOperator + internal readonly struct AddMultiplyOperator : ITernaryOperator where T : IAdditionOperators, IMultiplyOperators { - public static float Invoke(float x, float y, float z) => (x + y) * z; - public static Vector128 Invoke(Vector128 x, Vector128 y, Vector128 z) => (x + y) * z; - public static Vector256 Invoke(Vector256 x, Vector256 y, Vector256 z) => (x + y) * z; - public static Vector512 Invoke(Vector512 x, Vector512 y, Vector512 z) => (x + y) * z; + public static T Invoke(T x, T y, T z) => (x + y) * z; + public static Vector128 Invoke(Vector128 x, Vector128 y, Vector128 z) => (x + y) * z; + public static Vector256 Invoke(Vector256 x, Vector256 y, Vector256 z) => (x + y) * z; + public static Vector512 Invoke(Vector512 x, Vector512 y, Vector512 z) => (x + y) * z; } /// (x * y) + z - private readonly struct MultiplyAddOperator : ITernaryOperator + internal readonly struct MultiplyAddOperator : ITernaryOperator where T : IAdditionOperators, IMultiplyOperators { - public static float Invoke(float x, float y, float z) => (x * y) + z; - public static Vector128 Invoke(Vector128 x, Vector128 y, Vector128 z) => (x * y) + z; - public static Vector256 Invoke(Vector256 x, Vector256 y, Vector256 z) => (x * y) + z; - public static Vector512 Invoke(Vector512 x, Vector512 y, Vector512 z) => (x * y) + z; + public static T Invoke(T x, T y, T z) => (x * y) + z; + public static Vector128 Invoke(Vector128 x, Vector128 y, Vector128 z) => (x * y) + z; + public static Vector256 Invoke(Vector256 x, Vector256 y, Vector256 z) => (x * y) + z; + public static Vector512 Invoke(Vector512 x, Vector512 y, Vector512 z) => (x * y) + z; } /// x - private readonly struct IdentityOperator : IUnaryOperator + internal readonly struct IdentityOperator : IUnaryOperator { - public static float Invoke(float x) => x; - public static Vector128 Invoke(Vector128 x) => x; - public static Vector256 Invoke(Vector256 x) => x; - public static Vector512 Invoke(Vector512 x) => x; + public static bool Vectorizable => true; + public static T Invoke(T x) => x; + public static Vector128 Invoke(Vector128 x) => x; + public static Vector256 Invoke(Vector256 x) => x; + public static Vector512 Invoke(Vector512 x) => x; } /// x * x - private readonly struct SquaredOperator : IUnaryOperator + internal readonly struct SquaredOperator : IUnaryOperator where T : IMultiplyOperators { - public static float Invoke(float x) => x * x; - public static Vector128 Invoke(Vector128 x) => x * x; - public static Vector256 Invoke(Vector256 x) => x * x; - public static Vector512 Invoke(Vector512 x) => x * x; + public static bool Vectorizable => true; + public static T Invoke(T x) => x * x; + public static Vector128 Invoke(Vector128 x) => x * x; + public static Vector256 Invoke(Vector256 x) => x * x; + public static Vector512 Invoke(Vector512 x) => x * x; } - /// MathF.Abs(x) - private readonly struct AbsoluteOperator : IUnaryOperator + /// T.Abs(x) + internal readonly struct 
AbsoluteOperator : IUnaryOperator where T : INumberBase { - public static float Invoke(float x) => MathF.Abs(x); - public static Vector128 Invoke(Vector128 x) => Vector128.Abs(x); - public static Vector256 Invoke(Vector256 x) => Vector256.Abs(x); - public static Vector512 Invoke(Vector512 x) => Vector512.Abs(x); + public static bool Vectorizable => true; + + public static T Invoke(T x) => T.Abs(x); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static Vector128 Invoke(Vector128 x) + { + if (typeof(T) == typeof(sbyte) || + typeof(T) == typeof(short) || + typeof(T) == typeof(int) || + typeof(T) == typeof(long) || + typeof(T) == typeof(nint)) + { + // Handle signed integers specially, in order to throw if any attempt is made to + // take the absolute value of the minimum value of the type, which doesn't have + // a positive absolute value representation. Since x == -x holds for zero as well + // as MinValue, test whether the selected result is still negative, which is true + // only for MinValue. + Vector128 abs = Vector128.ConditionalSelect(Vector128.LessThan(x, Vector128.Zero), -x, x); + if (Vector128.LessThan(abs, Vector128.Zero) != Vector128.Zero) + { + ThrowNegateTwosCompOverflow(); + } + + return abs; + } + + return Vector128.Abs(x); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static Vector256 Invoke(Vector256 x) + { + if (typeof(T) == typeof(sbyte) || + typeof(T) == typeof(short) || + typeof(T) == typeof(int) || + typeof(T) == typeof(long) || + typeof(T) == typeof(nint)) + { + // Handle signed integers specially, in order to throw if any attempt is made to + // take the absolute value of the minimum value of the type, which doesn't have + // a positive absolute value representation. Since x == -x holds for zero as well + // as MinValue, test whether the selected result is still negative, which is true + // only for MinValue. + Vector256 abs = Vector256.ConditionalSelect(Vector256.LessThan(x, Vector256.Zero), -x, x); + if (Vector256.LessThan(abs, Vector256.Zero) != Vector256.Zero) + { + ThrowNegateTwosCompOverflow(); + } + + return abs; + } + + return Vector256.Abs(x); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static Vector512 Invoke(Vector512 x) + { + if (typeof(T) == typeof(sbyte) || + typeof(T) == typeof(short) || + typeof(T) == typeof(int) || + typeof(T) == typeof(long) || + typeof(T) == typeof(nint)) + { + // Handle signed integers specially, in order to throw if any attempt is made to + // take the absolute value of the minimum value of the type, which doesn't have + // a positive absolute value representation. Since x == -x holds for zero as well + // as MinValue, test whether the selected result is still negative, which is true + // only for MinValue. + Vector512 abs = Vector512.ConditionalSelect(Vector512.LessThan(x, Vector512.Zero), -x, x); + if (Vector512.LessThan(abs, Vector512.Zero) != Vector512.Zero) + { + ThrowNegateTwosCompOverflow(); + } + + return abs; + } + + return Vector512.Abs(x); + } } - /// MathF.Exp(x) - private readonly struct ExpOperator : IUnaryOperator + /// T.Exp(x) + internal readonly struct ExpOperator : IUnaryOperator + where T : IExponentialFunctions { // This code is based on `vrs4_expf` from amd/aocl-libm-ose // Copyright (C) 2019-2022 Advanced Micro Devices, Inc. All rights reserved.
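A scalar sketch of the signed-integer guard in AbsoluteOperator above may help: in two's complement, T.MinValue is the only value whose negation overflows, and because x == -x also holds for zero, the guard has to test the sign of the selected result rather than x == -x alone. The helper below is illustrative, not part of the patch:

// Illustrative scalar equivalent of AbsoluteOperator's signed-integer guard
// (hypothetical helper; the patch does the same thing with vector masks).
static int CheckedAbs(int x)
{
    int result = x < 0 ? unchecked(-x) : x; // unchecked(-int.MinValue) wraps back to int.MinValue
    if (result < 0)                         // only int.MinValue is still negative after the select
    {
        throw new OverflowException();      // ThrowNegateTwosCompOverflow() in the patch
    }
    return result;
}

This mirrors Math.Abs and T.Abs, which likewise throw OverflowException for int.MinValue, so the vectorized path stays behaviorally identical to the scalar fallback.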
@@ -10398,10 +11056,15 @@ public static Vector512 Invoke(Vector512 x, Vector512 y) private const double C5 = 0.009676036358193323; private const double C6 = 0.001341000536524434; - public static float Invoke(float x) => MathF.Exp(x); + public static bool Vectorizable => typeof(T) == typeof(float); - public static Vector128 Invoke(Vector128 x) + public static T Invoke(T x) => T.Exp(x); + + public static Vector128 Invoke(Vector128 t) { + Debug.Assert(typeof(T) == typeof(float)); + Vector128 x = t.AsSingle(); + // Convert x to double precision (Vector128 xl, Vector128 xu) = Vector128.Widen(x); @@ -10473,11 +11136,14 @@ public static Vector128 Invoke(Vector128 x) ret = Vector128.AndNot(ret, Vector128.LessThan(x, Vector128.Create(V_EXPF_MIN))); } - return ret; + return ret.As(); } - public static Vector256 Invoke(Vector256 x) + public static Vector256 Invoke(Vector256 t) { + Debug.Assert(typeof(T) == typeof(float)); + Vector256 x = t.AsSingle(); + // Convert x to double precision (Vector256 xl, Vector256 xu) = Vector256.Widen(x); @@ -10549,11 +11215,14 @@ public static Vector256 Invoke(Vector256 x) ret = Vector256.AndNot(ret, Vector256.LessThan(x, Vector256.Create(V_EXPF_MIN))); } - return ret; + return ret.As(); } - public static Vector512 Invoke(Vector512 x) + public static Vector512 Invoke(Vector512 t) { + Debug.Assert(typeof(T) == typeof(float)); + Vector512 x = t.AsSingle(); + // Convert x to double precision (Vector512 xl, Vector512 xu) = Vector512.Widen(x); @@ -10625,12 +11294,13 @@ public static Vector512 Invoke(Vector512 x) ret = Vector512.AndNot(ret, Vector512.LessThan(x, Vector512.Create(V_EXPF_MIN))); } - return ret; + return ret.As(); } } - /// MathF.Cosh(x) - private readonly struct CoshOperator : IUnaryOperator + /// T.Cosh(x) + internal readonly struct CoshOperator : IUnaryOperator + where T : IHyperbolicFunctions { // This code is based on `vrs4_coshf` from amd/aocl-libm-ose // Copyright (C) 2008-2022 Advanced Micro Devices, Inc. All rights reserved. 
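The Vectorizable flag introduced above is how each operator opts in or out of the SIMD paths per element type: ExpOperator reports true only for float, so the dispatch helpers take the vector path for float and fall back to per-element T.Exp(x) for double, Half, and the rest. A minimal sketch of that gating pattern, under assumed names (the library's real InvokeSpanIntoSpan-style helpers are internal and also handle remainders and the 256/512-bit widths):

// Minimal sketch of Vectorizable-gated dispatch; Apply and IUnaryOp are made-up
// names mirroring the 128-bit slice of this diff's private IUnaryOperator.
using System;
using System.Runtime.Intrinsics;

interface IUnaryOp<T>
{
    static abstract bool Vectorizable { get; }
    static abstract T Invoke(T x);
    static abstract Vector128<T> Invoke(Vector128<T> x);
}

static class Dispatch
{
    // Assumes destination.Length >= x.Length; the real helpers validate this and throw.
    public static void Apply<T, TOp>(ReadOnlySpan<T> x, Span<T> destination)
        where T : struct
        where TOp : IUnaryOp<T>
    {
        int i = 0;
        if (TOp.Vectorizable && Vector128.IsHardwareAccelerated)
        {
            // Vector path: process whole Vector128-sized chunks.
            for (; i <= x.Length - Vector128<T>.Count; i += Vector128<T>.Count)
            {
                TOp.Invoke(Vector128.Create(x.Slice(i, Vector128<T>.Count))).CopyTo(destination.Slice(i));
            }
        }
        for (; i < x.Length; i++)
        {
            destination[i] = TOp.Invoke(x[i]); // scalar fallback, e.g. T.Exp
        }
    }
}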
@@ -10657,32 +11327,44 @@ public static Vector512 Invoke(Vector512 x) private const float HALFV = 1.0000138f; private const float INVV2 = 0.24999309f; - public static float Invoke(float x) => MathF.Cosh(x); + public static bool Vectorizable => typeof(T) == typeof(float); + + public static T Invoke(T x) => T.Cosh(x); - public static Vector128 Invoke(Vector128 x) + public static Vector128 Invoke(Vector128 t) { + Debug.Assert(typeof(T) == typeof(float)); + Vector128 x = t.AsSingle(); + Vector128 y = Vector128.Abs(x); - Vector128 z = ExpOperator.Invoke(y - Vector128.Create(LOGV)); - return Vector128.Create(HALFV) * (z + (Vector128.Create(INVV2) / z)); + Vector128 z = ExpOperator.Invoke(y - Vector128.Create(LOGV)); + return (Vector128.Create(HALFV) * (z + (Vector128.Create(INVV2) / z))).As(); } - public static Vector256 Invoke(Vector256 x) + public static Vector256 Invoke(Vector256 t) { + Debug.Assert(typeof(T) == typeof(float)); + Vector256 x = t.AsSingle(); + Vector256 y = Vector256.Abs(x); - Vector256 z = ExpOperator.Invoke(y - Vector256.Create(LOGV)); - return Vector256.Create(HALFV) * (z + (Vector256.Create(INVV2) / z)); + Vector256 z = ExpOperator.Invoke(y - Vector256.Create(LOGV)); + return (Vector256.Create(HALFV) * (z + (Vector256.Create(INVV2) / z))).As(); } - public static Vector512 Invoke(Vector512 x) + public static Vector512 Invoke(Vector512 t) { + Debug.Assert(typeof(T) == typeof(float)); + Vector512 x = t.AsSingle(); + Vector512 y = Vector512.Abs(x); - Vector512 z = ExpOperator.Invoke(y - Vector512.Create(LOGV)); - return Vector512.Create(HALFV) * (z + (Vector512.Create(INVV2) / z)); + Vector512 z = ExpOperator.Invoke(y - Vector512.Create(LOGV)); + return (Vector512.Create(HALFV) * (z + (Vector512.Create(INVV2) / z))).As(); } } - /// MathF.Sinh(x) - private readonly struct SinhOperator : IUnaryOperator + /// T.Sinh(x) + internal readonly struct SinhOperator : IUnaryOperator + where T : IHyperbolicFunctions { // Same as cosh, but with `z -` rather than `z +`, and with the sign // flipped on the result based on the sign of the input. 
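The Cosh kernel above is easier to follow with the constants written out. Taking V such that HALFV = V/2 (so V ≈ 2.0000276 and LOGV = ln V ≈ 0.693161, which is how the LOGV value can be derived from the HALFV shown here), z = e^(|x| − LOGV) = e^|x|/V, and cosh x = (e^|x| + e^−|x|)/2 = (V/2)·z + (1/(2V))/z = HALFV·(z + INVV2/z), since HALFV·INVV2 ≈ 1/(2V). Using |x| is lossless because cosh is even. A standalone numeric check (constant values derived as just described):

// Standalone check of the Cosh identity; LOGV's value is derived from HALFV = V/2.
using System;

const float LOGV = 0.693161f, HALFV = 1.0000138f, INVV2 = 0.24999309f;
float x = 1.5f;
float z = MathF.Exp(MathF.Abs(x) - LOGV);    // z = e^|x| / V
float viaIdentity = HALFV * (z + INVV2 / z); // (V/2)z + (1/(2V))/z
Console.WriteLine($"{viaIdentity} vs {MathF.Cosh(x)}"); // agree to within float rounding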
@@ -10692,38 +11374,50 @@ public static Vector512 Invoke(Vector512 x) private const float HALFV = 1.0000138f; private const float INVV2 = 0.24999309f; - public static float Invoke(float x) => MathF.Sinh(x); + public static bool Vectorizable => typeof(T) == typeof(float); + + public static T Invoke(T x) => T.Sinh(x); - public static Vector128 Invoke(Vector128 x) + public static Vector128 Invoke(Vector128 t) { + Debug.Assert(typeof(T) == typeof(float)); + Vector128 x = t.AsSingle(); + Vector128 y = Vector128.Abs(x); - Vector128 z = ExpOperator.Invoke(y - Vector128.Create(LOGV)); + Vector128 z = ExpOperator.Invoke(y - Vector128.Create(LOGV)); Vector128 result = Vector128.Create(HALFV) * (z - (Vector128.Create(INVV2) / z)); Vector128 sign = x.AsUInt32() & Vector128.Create(~SIGN_MASK); - return (sign ^ result.AsUInt32()).AsSingle(); + return (sign ^ result.AsUInt32()).AsSingle().As(); } - public static Vector256 Invoke(Vector256 x) + public static Vector256 Invoke(Vector256 t) { + Debug.Assert(typeof(T) == typeof(float)); + Vector256 x = t.AsSingle(); + Vector256 y = Vector256.Abs(x); - Vector256 z = ExpOperator.Invoke(y - Vector256.Create(LOGV)); + Vector256 z = ExpOperator.Invoke(y - Vector256.Create(LOGV)); Vector256 result = Vector256.Create(HALFV) * (z - (Vector256.Create(INVV2) / z)); Vector256 sign = x.AsUInt32() & Vector256.Create(~SIGN_MASK); - return (sign ^ result.AsUInt32()).AsSingle(); + return (sign ^ result.AsUInt32()).AsSingle().As(); } - public static Vector512 Invoke(Vector512 x) + public static Vector512 Invoke(Vector512 t) { + Debug.Assert(typeof(T) == typeof(float)); + Vector512 x = t.AsSingle(); + Vector512 y = Vector512.Abs(x); - Vector512 z = ExpOperator.Invoke(y - Vector512.Create(LOGV)); + Vector512 z = ExpOperator.Invoke(y - Vector512.Create(LOGV)); Vector512 result = Vector512.Create(HALFV) * (z - (Vector512.Create(INVV2) / z)); Vector512 sign = x.AsUInt32() & Vector512.Create(~SIGN_MASK); - return (sign ^ result.AsUInt32()).AsSingle(); + return (sign ^ result.AsUInt32()).AsSingle().As(); } } - /// MathF.Tanh(x) - private readonly struct TanhOperator : IUnaryOperator + /// T.Tanh(x) + internal readonly struct TanhOperator : IUnaryOperator + where T : IHyperbolicFunctions { // This code is based on `vrs4_tanhf` from amd/aocl-libm-ose // Copyright (C) 2008-2022 Advanced Micro Devices, Inc. All rights reserved. 
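SinhOperator above reuses the Cosh machinery on |x| (with z − INVV2/z, since sinh |x| = (V/2)z − (1/(2V))/z) and then transplants the input's sign: SIGN_MASK is 0x7FFFFFFF, so ~SIGN_MASK isolates the sign bit, and XORing it into the result negates exactly when x was negative, which is correct because sinh is odd. A scalar sketch of the same bit trick (helper name illustrative; LOGV derived as in the Cosh note):

// Scalar sketch of SinhOperator's sign transplant (illustrative helper).
static float SinhViaExp(float x)
{
    const float LOGV = 0.693161f, HALFV = 1.0000138f, INVV2 = 0.24999309f;
    float y = MathF.Abs(x);
    float z = MathF.Exp(y - LOGV);
    float magnitude = HALFV * (z - INVV2 / z);                     // sinh(|x|)
    uint sign = BitConverter.SingleToUInt32Bits(x) & 0x8000_0000u; // x's sign bit, i.e. ~SIGN_MASK
    return BitConverter.UInt32BitsToSingle(sign ^ BitConverter.SingleToUInt32Bits(magnitude));
}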
@@ -10746,35 +11440,47 @@ public static Vector512 Invoke(Vector512 x) private const uint SIGN_MASK = 0x7FFFFFFF; - public static float Invoke(float x) => MathF.Tanh(x); + public static bool Vectorizable => typeof(T) == typeof(float); - public static Vector128 Invoke(Vector128 x) + public static T Invoke(T x) => T.Tanh(x); + + public static Vector128 Invoke(Vector128 t) { + Debug.Assert(typeof(T) == typeof(float)); + Vector128 x = t.AsSingle(); + Vector128 y = Vector128.Abs(x); - Vector128 z = ExpOperator.Invoke(Vector128.Create(-2f) * y) - Vector128.Create(1f); + Vector128 z = ExpOperator.Invoke(Vector128.Create(-2f) * y) - Vector128.Create(1f); Vector128 sign = x.AsUInt32() & Vector128.Create(~SIGN_MASK); - return (sign ^ (-z / (z + Vector128.Create(2f))).AsUInt32()).AsSingle(); + return (sign ^ (-z / (z + Vector128.Create(2f))).AsUInt32()).AsSingle().As(); } - public static Vector256 Invoke(Vector256 x) + public static Vector256 Invoke(Vector256 t) { + Debug.Assert(typeof(T) == typeof(float)); + Vector256 x = t.AsSingle(); + Vector256 y = Vector256.Abs(x); - Vector256 z = ExpOperator.Invoke(Vector256.Create(-2f) * y) - Vector256.Create(1f); + Vector256 z = ExpOperator.Invoke(Vector256.Create(-2f) * y) - Vector256.Create(1f); Vector256 sign = x.AsUInt32() & Vector256.Create(~SIGN_MASK); - return (sign ^ (-z / (z + Vector256.Create(2f))).AsUInt32()).AsSingle(); + return (sign ^ (-z / (z + Vector256.Create(2f))).AsUInt32()).AsSingle().As(); } - public static Vector512 Invoke(Vector512 x) + public static Vector512 Invoke(Vector512 t) { + Debug.Assert(typeof(T) == typeof(float)); + Vector512 x = t.AsSingle(); + Vector512 y = Vector512.Abs(x); - Vector512 z = ExpOperator.Invoke(Vector512.Create(-2f) * y) - Vector512.Create(1f); + Vector512 z = ExpOperator.Invoke(Vector512.Create(-2f) * y) - Vector512.Create(1f); Vector512 sign = x.AsUInt32() & Vector512.Create(~SIGN_MASK); - return (sign ^ (-z / (z + Vector512.Create(2f))).AsUInt32()).AsSingle(); + return (sign ^ (-z / (z + Vector512.Create(2f))).AsUInt32()).AsSingle().As(); } } - /// MathF.Log(x) - private readonly struct LogOperator : IUnaryOperator + /// T.Log(x) + internal readonly struct LogOperator : IUnaryOperator + where T : ILogarithmicFunctions { // This code is based on `vrs4_logf` from amd/aocl-libm-ose // Copyright (C) 2018-2019 Advanced Micro Devices, Inc. All rights reserved. 
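TanhOperator above needs only a single exp evaluation: with z = e^(−2|x|) − 1, the identity tanh(|x|) = (1 − e^(−2|x|))/(1 + e^(−2|x|)) = −z/(z + 2) gives the magnitude, and the input's sign is XORed back in exactly as in Sinh. A standalone check of the identity:

// Standalone check of tanh(|x|) == -z / (z + 2) with z = e^(-2|x|) - 1.
using System;

float y = 0.75f; // plays the role of |x|
float z = MathF.Exp(-2f * y) - 1f;
Console.WriteLine($"{-z / (z + 2f)} vs {MathF.Tanh(y)}"); // agree to within float rounding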
@@ -10847,10 +11553,15 @@ public static Vector512 Invoke(Vector512 x) private const float C9 = 0.14401625f; private const float C10 = -0.13657966f; - public static float Invoke(float x) => MathF.Log(x); + public static bool Vectorizable => typeof(T) == typeof(float); + + public static T Invoke(T x) => T.Log(x); - public static Vector128 Invoke(Vector128 x) + public static Vector128 Invoke(Vector128 t) { + Debug.Assert(typeof(T) == typeof(float)); + Vector128 x = t.AsSingle(); + Vector128 specialResult = x; // x is subnormal or infinity or NaN @@ -10915,11 +11626,14 @@ public static Vector128 Invoke(Vector128 x) specialMask.AsSingle(), specialResult, n * Vector128.Create(V_LN2) + q - ); + ).As(); } - public static Vector256 Invoke(Vector256 x) + public static Vector256 Invoke(Vector256 t) { + Debug.Assert(typeof(T) == typeof(float)); + Vector256 x = t.AsSingle(); + Vector256 specialResult = x; // x is subnormal or infinity or NaN @@ -10984,11 +11698,14 @@ public static Vector256 Invoke(Vector256 x) specialMask.AsSingle(), specialResult, n * Vector256.Create(V_LN2) + q - ); + ).As(); } - public static Vector512 Invoke(Vector512 x) + public static Vector512 Invoke(Vector512 t) { + Debug.Assert(typeof(T) == typeof(float)); + Vector512 x = t.AsSingle(); + Vector512 specialResult = x; // x is subnormal or infinity or NaN @@ -11053,12 +11770,13 @@ public static Vector512 Invoke(Vector512 x) specialMask.AsSingle(), specialResult, n * Vector512.Create(V_LN2) + q - ); + ).As(); } } - /// MathF.Log2(x) - private readonly struct Log2Operator : IUnaryOperator + /// T.Log2(x) + internal readonly struct Log2Operator : IUnaryOperator + where T : ILogarithmicFunctions { // This code is based on `vrs4_log2f` from amd/aocl-libm-ose // Copyright (C) 2021-2022 Advanced Micro Devices, Inc. All rights reserved. 
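LogOperator above is the classic argument reduction: split x = 2^n · m with m in [1, 2), approximate log(m) with a polynomial (whose tail coefficients C9 and C10 appear here), and reconstruct log(x) = n·ln 2 + q, which is the `n * Vector128.Create(V_LN2) + q` expression in the vector bodies. Subnormal, zero, negative, infinite, and NaN inputs are detoured through specialResult under specialMask. A scalar sketch of the reduction for positive normal floats (helper name illustrative; MathF.Log stands in for the polynomial):

// Scalar sketch of the x = 2^n * m reduction behind LogOperator (illustrative).
static float LogViaReduction(float x)
{
    // Assumes x is positive and normal; the kernel handles everything else via specialMask.
    int bits = BitConverter.SingleToInt32Bits(x);
    int n = (bits >> 23) - 127;                                                   // unbiased exponent
    float m = BitConverter.Int32BitsToSingle((bits & 0x007F_FFFF) | 0x3F80_0000); // mantissa scaled into [1, 2)
    return (n * 0.69314718f) + MathF.Log(m); // n*ln2 + q, with MathF.Log in place of the polynomial
}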
@@ -11126,10 +11844,15 @@ public static Vector512 Invoke(Vector512 x) private const float C8 = -0.22616665f; private const float C9 = 0.21228963f; - public static float Invoke(float x) => MathF.Log2(x); + public static bool Vectorizable => typeof(T) == typeof(float); + + public static T Invoke(T x) => T.Log2(x); - public static Vector128 Invoke(Vector128 x) + public static Vector128 Invoke(Vector128 t) { + Debug.Assert(typeof(T) == typeof(float)); + Vector128 x = t.AsSingle(); + Vector128 specialResult = x; // x is subnormal or infinity or NaN @@ -11194,11 +11917,14 @@ public static Vector128 Invoke(Vector128 x) specialMask.AsSingle(), specialResult, n + poly - ); + ).As(); } - public static Vector256 Invoke(Vector256 x) + public static Vector256 Invoke(Vector256 t) { + Debug.Assert(typeof(T) == typeof(float)); + Vector256 x = t.AsSingle(); + Vector256 specialResult = x; // x is subnormal or infinity or NaN @@ -11263,11 +11989,14 @@ public static Vector256 Invoke(Vector256 x) specialMask.AsSingle(), specialResult, n + poly - ); + ).As(); } - public static Vector512 Invoke(Vector512 x) + public static Vector512 Invoke(Vector512 t) { + Debug.Assert(typeof(T) == typeof(float)); + Vector512 x = t.AsSingle(); + Vector512 specialResult = x; // x is subnormal or infinity or NaN @@ -11332,108 +12061,112 @@ public static Vector512 Invoke(Vector512 x) specialMask.AsSingle(), specialResult, n + poly - ); + ).As(); } } [MethodImpl(MethodImplOptions.AggressiveInlining)] - private static Vector128 ElementWiseSelect(Vector128 mask, Vector128 left, Vector128 right) - { - if (Sse41.IsSupported) - return Sse41.BlendVariable(left, right, ~mask); - - return Vector128.ConditionalSelect(mask, left, right); - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - private static Vector128 ElementWiseSelect(Vector128 mask, Vector128 left, Vector128 right) + private static Vector128 ElementWiseSelect(Vector128 mask, Vector128 left, Vector128 right) { if (Sse41.IsSupported) - return Sse41.BlendVariable(left, right, ~mask); + { + if (typeof(T) == typeof(byte)) return Sse41.BlendVariable(left.AsByte(), right.AsByte(), (~mask).AsByte()).As(); + if (typeof(T) == typeof(sbyte)) return Sse41.BlendVariable(left.AsSByte(), right.AsSByte(), (~mask).AsSByte()).As(); + if (typeof(T) == typeof(ushort)) return Sse41.BlendVariable(left.AsUInt16(), right.AsUInt16(), (~mask).AsUInt16()).As(); + if (typeof(T) == typeof(short)) return Sse41.BlendVariable(left.AsInt16(), right.AsInt16(), (~mask).AsInt16()).As(); + if (typeof(T) == typeof(uint)) return Sse41.BlendVariable(left.AsUInt32(), right.AsUInt32(), (~mask).AsUInt32()).As(); + if (typeof(T) == typeof(int)) return Sse41.BlendVariable(left.AsInt32(), right.AsInt32(), (~mask).AsInt32()).As(); + if (typeof(T) == typeof(ulong)) return Sse41.BlendVariable(left.AsUInt64(), right.AsUInt64(), (~mask).AsUInt64()).As(); + if (typeof(T) == typeof(long)) return Sse41.BlendVariable(left.AsInt64(), right.AsInt64(), (~mask).AsInt64()).As(); + if (typeof(T) == typeof(float)) return Sse41.BlendVariable(left.AsSingle(), right.AsSingle(), (~mask).AsSingle()).As(); + if (typeof(T) == typeof(double)) return Sse41.BlendVariable(left.AsDouble(), right.AsDouble(), (~mask).AsDouble()).As(); + } return Vector128.ConditionalSelect(mask, left, right); } [MethodImpl(MethodImplOptions.AggressiveInlining)] - private static Vector256 ElementWiseSelect(Vector256 mask, Vector256 left, Vector256 right) - { - if (Avx2.IsSupported) - return Avx2.BlendVariable(left, right, ~mask); - - return 
Vector256.ConditionalSelect(mask, left, right); - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - private static Vector256 ElementWiseSelect(Vector256 mask, Vector256 left, Vector256 right) + private static Vector256 ElementWiseSelect(Vector256 mask, Vector256 left, Vector256 right) { if (Avx2.IsSupported) - return Avx2.BlendVariable(left, right, ~mask); + { + if (typeof(T) == typeof(byte)) return Avx2.BlendVariable(left.AsByte(), right.AsByte(), (~mask).AsByte()).As(); + if (typeof(T) == typeof(sbyte)) return Avx2.BlendVariable(left.AsSByte(), right.AsSByte(), (~mask).AsSByte()).As(); + if (typeof(T) == typeof(ushort)) return Avx2.BlendVariable(left.AsUInt16(), right.AsUInt16(), (~mask).AsUInt16()).As(); + if (typeof(T) == typeof(short)) return Avx2.BlendVariable(left.AsInt16(), right.AsInt16(), (~mask).AsInt16()).As(); + if (typeof(T) == typeof(uint)) return Avx2.BlendVariable(left.AsUInt32(), right.AsUInt32(), (~mask).AsUInt32()).As(); + if (typeof(T) == typeof(int)) return Avx2.BlendVariable(left.AsInt32(), right.AsInt32(), (~mask).AsInt32()).As(); + if (typeof(T) == typeof(ulong)) return Avx2.BlendVariable(left.AsUInt64(), right.AsUInt64(), (~mask).AsUInt64()).As(); + if (typeof(T) == typeof(long)) return Avx2.BlendVariable(left.AsInt64(), right.AsInt64(), (~mask).AsInt64()).As(); + if (typeof(T) == typeof(float)) return Avx2.BlendVariable(left.AsSingle(), right.AsSingle(), (~mask).AsSingle()).As(); + if (typeof(T) == typeof(double)) return Avx2.BlendVariable(left.AsDouble(), right.AsDouble(), (~mask).AsDouble()).As(); + } return Vector256.ConditionalSelect(mask, left, right); } [MethodImpl(MethodImplOptions.AggressiveInlining)] - private static Vector512 ElementWiseSelect(Vector512 mask, Vector512 left, Vector512 right) - { - if (Avx512F.IsSupported) - return Avx512F.BlendVariable(left, right, ~mask); - - return Vector512.ConditionalSelect(mask, left, right); - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - private static Vector512 ElementWiseSelect(Vector512 mask, Vector512 left, Vector512 right) + private static Vector512 ElementWiseSelect(Vector512 mask, Vector512 left, Vector512 right) { if (Avx512F.IsSupported) - return Avx512F.BlendVariable(left, right, ~mask); + { + if (typeof(T) == typeof(uint)) return Avx512F.BlendVariable(left.AsUInt32(), right.AsUInt32(), (~mask).AsUInt32()).As(); + if (typeof(T) == typeof(int)) return Avx512F.BlendVariable(left.AsInt32(), right.AsInt32(), (~mask).AsInt32()).As(); + if (typeof(T) == typeof(ulong)) return Avx512F.BlendVariable(left.AsUInt64(), right.AsUInt64(), (~mask).AsUInt64()).As(); + if (typeof(T) == typeof(long)) return Avx512F.BlendVariable(left.AsInt64(), right.AsInt64(), (~mask).AsInt64()).As(); + if (typeof(T) == typeof(float)) return Avx512F.BlendVariable(left.AsSingle(), right.AsSingle(), (~mask).AsSingle()).As(); + if (typeof(T) == typeof(double)) return Avx512F.BlendVariable(left.AsDouble(), right.AsDouble(), (~mask).AsDouble()).As(); + } return Vector512.ConditionalSelect(mask, left, right); } /// 1f / (1f + MathF.Exp(-x)) - private readonly struct SigmoidOperator : IUnaryOperator + internal readonly struct SigmoidOperator : IUnaryOperator where T : IExponentialFunctions { - public static float Invoke(float x) => 1.0f / (1.0f + MathF.Exp(-x)); - public static Vector128 Invoke(Vector128 x) => Vector128.Create(1f) / (Vector128.Create(1f) + ExpOperator.Invoke(-x)); - public static Vector256 Invoke(Vector256 x) => Vector256.Create(1f) / (Vector256.Create(1f) + ExpOperator.Invoke(-x)); - public static 
Vector512 Invoke(Vector512 x) => Vector512.Create(1f) / (Vector512.Create(1f) + ExpOperator.Invoke(-x)); + public static bool Vectorizable => typeof(T) == typeof(float); + public static T Invoke(T x) => T.One / (T.One + T.Exp(-x)); + public static Vector128 Invoke(Vector128 x) => Vector128.Create(T.One) / (Vector128.Create(T.One) + ExpOperator.Invoke(-x)); + public static Vector256 Invoke(Vector256 x) => Vector256.Create(T.One) / (Vector256.Create(T.One) + ExpOperator.Invoke(-x)); + public static Vector512 Invoke(Vector512 x) => Vector512.Create(T.One) / (Vector512.Create(T.One) + ExpOperator.Invoke(-x)); } /// Operator that takes one input value and returns a single value. - private interface IUnaryOperator + private interface IUnaryOperator { - static abstract float Invoke(float x); - static abstract Vector128 Invoke(Vector128 x); - static abstract Vector256 Invoke(Vector256 x); - static abstract Vector512 Invoke(Vector512 x); + static abstract bool Vectorizable { get; } + static abstract T Invoke(T x); + static abstract Vector128 Invoke(Vector128 x); + static abstract Vector256 Invoke(Vector256 x); + static abstract Vector512 Invoke(Vector512 x); } /// Operator that takes two input values and returns a single value. - private interface IBinaryOperator + private interface IBinaryOperator { - static abstract float Invoke(float x, float y); - static abstract Vector128 Invoke(Vector128 x, Vector128 y); - static abstract Vector256 Invoke(Vector256 x, Vector256 y); - static abstract Vector512 Invoke(Vector512 x, Vector512 y); + static abstract T Invoke(T x, T y); + static abstract Vector128 Invoke(Vector128 x, Vector128 y); + static abstract Vector256 Invoke(Vector256 x, Vector256 y); + static abstract Vector512 Invoke(Vector512 x, Vector512 y); } - /// An IBinaryOperator that specializes horizontal aggregation of all elements in a vector. - private interface IAggregationOperator : IBinaryOperator + /// An IBinaryOperator that specializes horizontal aggregation of all elements in a vector. + private interface IAggregationOperator : IBinaryOperator { - static abstract float Invoke(Vector128 x); - static abstract float Invoke(Vector256 x); - static abstract float Invoke(Vector512 x); + static abstract T Invoke(Vector128 x); + static abstract T Invoke(Vector256 x); + static abstract T Invoke(Vector512 x); - static virtual float IdentityValue => throw new NotSupportedException(); + static virtual T IdentityValue => throw new NotSupportedException(); } /// Operator that takes three input values and returns a single value.
- private interface ITernaryOperator + private interface ITernaryOperator { - static abstract float Invoke(float x, float y, float z); - static abstract Vector128 Invoke(Vector128 x, Vector128 y, Vector128 z); - static abstract Vector256 Invoke(Vector256 x, Vector256 y, Vector256 z); - static abstract Vector512 Invoke(Vector512 x, Vector512 y, Vector512 z); + static abstract T Invoke(T x, T y, T z); + static abstract Vector128 Invoke(Vector128 x, Vector128 y, Vector128 z); + static abstract Vector256 Invoke(Vector256 x, Vector256 y, Vector256 z); + static abstract Vector512 Invoke(Vector512 x, Vector512 y, Vector512 z); } } } diff --git a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netstandard.cs b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netstandard/TensorPrimitives.Single.netstandard.cs similarity index 98% rename from src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netstandard.cs rename to src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netstandard/TensorPrimitives.Single.netstandard.cs index dcb7a100baac31..fe32027099c7e3 100644 --- a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netstandard.cs +++ b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netstandard/TensorPrimitives.Single.netstandard.cs @@ -1507,7 +1507,7 @@ static void VectorizedSmall(ref float xRef, ref float yRef, ref float dRef, nuin private static void InvokeSpanScalarIntoSpan( ReadOnlySpan x, float y, Span destination, TBinaryOperator op = default) where TBinaryOperator : struct, IBinaryOperator => - InvokeSpanScalarIntoSpan(x, y, destination, default, op); + InvokeSpanScalarIntoSpan(x, y, destination, default, op); /// /// Performs an element-wise operation on and , @@ -2901,7 +2901,7 @@ private static Vector CreateAlignmentMaskSingleVector(int count) Debug.Assert(Vector.Count is 4 or 8 or 16); return AsVector( - ref Unsafe.As(ref MemoryMarshal.GetReference(AlignmentUInt32Mask_16x16)), + ref Unsafe.As(ref MemoryMarshal.GetReference(AlignmentUInt32Mask_16x17)), (count * 16)); } @@ -2914,12 +2914,12 @@ private static Vector CreateRemainderMaskSingleVector(int count) Debug.Assert(Vector.Count is 4 or 8 or 16); return AsVector( - ref Unsafe.As(ref MemoryMarshal.GetReference(RemainderUInt32Mask_16x16)), + ref Unsafe.As(ref MemoryMarshal.GetReference(RemainderUInt32Mask_16x17)), (count * 16) + (16 - Vector.Count)); } /// x + y - private readonly struct AddOperator : IAggregationOperator + private readonly struct AddOperator_Single : IAggregationOperator { public float Invoke(float x, float y) => x + y; public Vector Invoke(Vector x, Vector y) => x + y; @@ -2927,14 +2927,14 @@ ref Unsafe.As(ref MemoryMarshal.GetReference(RemainderUInt32Mask_16 } /// x - y - private readonly struct SubtractOperator : IBinaryOperator + private readonly struct SubtractOperator_Single : IBinaryOperator { public float Invoke(float x, float y) => x - y; public Vector Invoke(Vector x, Vector y) => x - y; } /// (x - y) * (x - y) - private readonly struct SubtractSquaredOperator : IBinaryOperator + private readonly struct SubtractSquaredOperator_Single : IBinaryOperator { public float Invoke(float x, float y) { @@ -2950,7 +2950,7 @@ public Vector Invoke(Vector x, Vector y) } /// x * y - private readonly struct MultiplyOperator : IAggregationOperator + private readonly struct MultiplyOperator_Single : IAggregationOperator { public float Invoke(float x, float y) => x * y; public 
Vector Invoke(Vector x, Vector y) => x * y; @@ -2958,7 +2958,7 @@ public Vector Invoke(Vector x, Vector y) } /// x / y - private readonly struct DivideOperator : IBinaryOperator + private readonly struct DivideOperator_Single : IBinaryOperator { public float Invoke(float x, float y) => x / y; public Vector Invoke(Vector x, Vector y) => x / y; @@ -2972,7 +2972,7 @@ private interface IIndexOfOperator } /// Returns the index of MathF.Max(x, y) - private readonly struct IndexOfMaxOperator : IIndexOfOperator + private readonly struct IndexOfMaxOperator_Single : IIndexOfOperator { [MethodImpl(MethodImplOptions.AggressiveInlining)] public int Invoke(Vector result, Vector resultIndex) @@ -3037,7 +3037,7 @@ public int Invoke(ref float result, float current, int resultIndex, int curIndex } } - private readonly struct IndexOfMaxMagnitudeOperator : IIndexOfOperator + private readonly struct IndexOfMaxMagnitudeOperator_Single : IIndexOfOperator { [MethodImpl(MethodImplOptions.AggressiveInlining)] public int Invoke(ref float result, float current, int resultIndex, int curIndex) @@ -3107,7 +3107,7 @@ public void Invoke(ref Vector result, Vector current, ref Vector result, Vector current, ref Vector result, Vector current, ref VectorMathF.Max(x, y) (but without guaranteed NaN propagation) - private readonly struct MaxOperator : IBinaryOperator + private readonly struct MaxOperator_Single : IBinaryOperator { [MethodImpl(MethodImplOptions.AggressiveInlining)] public float Invoke(float x, float y) => @@ -3258,7 +3258,7 @@ public Vector Invoke(Vector x, Vector y) => } /// MathF.Max(x, y) - private readonly struct MaxPropagateNaNOperator : IBinaryOperator + private readonly struct MaxPropagateNaNOperator_Single : IBinaryOperator { [MethodImpl(MethodImplOptions.AggressiveInlining)] public float Invoke(float x, float y) => MathF.Max(x, y); @@ -3275,7 +3275,7 @@ public Vector Invoke(Vector x, Vector y) => } /// Operator to get x or y based on which has the larger MathF.Abs (but NaNs may not be propagated) - private readonly struct MaxMagnitudeOperator : IBinaryOperator + private readonly struct MaxMagnitudeOperator_Single : IBinaryOperator { [MethodImpl(MethodImplOptions.AggressiveInlining)] public float Invoke(float x, float y) @@ -3299,7 +3299,7 @@ public Vector Invoke(Vector x, Vector y) } /// Operator to get x or y based on which has the larger MathF.Abs - private readonly struct MaxMagnitudePropagateNaNOperator : IBinaryOperator + private readonly struct MaxMagnitudePropagateNaNOperator_Single : IBinaryOperator { [MethodImpl(MethodImplOptions.AggressiveInlining)] public float Invoke(float x, float y) @@ -3324,7 +3324,7 @@ public Vector Invoke(Vector x, Vector y) } /// MathF.Min(x, y) (but NaNs may not be propagated) - private readonly struct MinOperator : IBinaryOperator + private readonly struct MinOperator_Single : IBinaryOperator { [MethodImpl(MethodImplOptions.AggressiveInlining)] public float Invoke(float x, float y) => @@ -3340,7 +3340,7 @@ public Vector Invoke(Vector x, Vector y) => } /// MathF.Min(x, y) - private readonly struct MinPropagateNaNOperator : IBinaryOperator + private readonly struct MinPropagateNaNOperator_Single : IBinaryOperator { [MethodImpl(MethodImplOptions.AggressiveInlining)] public float Invoke(float x, float y) => MathF.Min(x, y); @@ -3357,7 +3357,7 @@ public Vector Invoke(Vector x, Vector y) => } /// Operator to get x or y based on which has the smaller MathF.Abs (but NaNs may not be propagated) - private readonly struct MinMagnitudeOperator : IBinaryOperator + private readonly 
struct MinMagnitudeOperator_Single : IBinaryOperator { [MethodImpl(MethodImplOptions.AggressiveInlining)] public float Invoke(float x, float y) @@ -3381,7 +3381,7 @@ public Vector Invoke(Vector x, Vector y) } /// Operator to get x or y based on which has the smaller MathF.Abs - private readonly struct MinMagnitudePropagateNaNOperator : IBinaryOperator + private readonly struct MinMagnitudePropagateNaNOperator_Single : IBinaryOperator { [MethodImpl(MethodImplOptions.AggressiveInlining)] public float Invoke(float x, float y) @@ -3407,7 +3407,7 @@ public Vector Invoke(Vector x, Vector y) } /// -x - private readonly struct NegateOperator : IUnaryOperator + private readonly struct NegateOperator_Single : IUnaryOperator { public bool CanVectorize => true; public float Invoke(float x) => -x; @@ -3415,21 +3415,21 @@ public Vector Invoke(Vector x, Vector y) } /// (x + y) * z - private readonly struct AddMultiplyOperator : ITernaryOperator + private readonly struct AddMultiplyOperator_Single : ITernaryOperator { public float Invoke(float x, float y, float z) => (x + y) * z; public Vector Invoke(Vector x, Vector y, Vector z) => (x + y) * z; } /// (x * y) + z - private readonly struct MultiplyAddOperator : ITernaryOperator + private readonly struct MultiplyAddOperator_Single : ITernaryOperator { public float Invoke(float x, float y, float z) => (x * y) + z; public Vector Invoke(Vector x, Vector y, Vector z) => (x * y) + z; } /// x - private readonly struct IdentityOperator : IUnaryOperator + private readonly struct IdentityOperator_Single : IUnaryOperator { public bool CanVectorize => true; public float Invoke(float x) => x; @@ -3437,7 +3437,7 @@ public Vector Invoke(Vector x, Vector y) } /// x * x - private readonly struct SquaredOperator : IUnaryOperator + private readonly struct SquaredOperator_Single : IUnaryOperator { public bool CanVectorize => true; public float Invoke(float x) => x * x; @@ -3445,7 +3445,7 @@ public Vector Invoke(Vector x, Vector y) } /// MathF.Abs(x) - private readonly struct AbsoluteOperator : IUnaryOperator + private readonly struct AbsoluteOperator_Single : IUnaryOperator { public bool CanVectorize => true; public float Invoke(float x) => MathF.Abs(x); @@ -3453,7 +3453,7 @@ public Vector Invoke(Vector x, Vector y) } /// MathF.Exp(x) - private readonly struct ExpOperator : IUnaryOperator + private readonly struct ExpOperator_Single : IUnaryOperator { public bool CanVectorize => false; public float Invoke(float x) => MathF.Exp(x); @@ -3463,7 +3463,7 @@ public Vector Invoke(Vector x) => } /// MathF.Sinh(x) - private readonly struct SinhOperator : IUnaryOperator + private readonly struct SinhOperator_Single : IUnaryOperator { public bool CanVectorize => false; public float Invoke(float x) => MathF.Sinh(x); @@ -3473,7 +3473,7 @@ public Vector Invoke(Vector x) => } /// MathF.Cosh(x) - private readonly struct CoshOperator : IUnaryOperator + private readonly struct CoshOperator_Single : IUnaryOperator { public bool CanVectorize => false; public float Invoke(float x) => MathF.Cosh(x); @@ -3483,7 +3483,7 @@ public Vector Invoke(Vector x) => } /// MathF.Tanh(x) - private readonly struct TanhOperator : IUnaryOperator + private readonly struct TanhOperator_Single : IUnaryOperator { public bool CanVectorize => false; public float Invoke(float x) => MathF.Tanh(x); @@ -3493,7 +3493,7 @@ public Vector Invoke(Vector x) => } /// MathF.Log(x) - private readonly struct LogOperator : IUnaryOperator + private readonly struct LogOperator_Single : IUnaryOperator { public bool CanVectorize => false; 
public float Invoke(float x) => MathF.Log(x); @@ -3503,7 +3503,7 @@ public Vector Invoke(Vector x) => } /// MathF.Log2(x) - private readonly struct Log2Operator : IUnaryOperator + private readonly struct Log2Operator_Single : IUnaryOperator { public bool CanVectorize => false; public float Invoke(float x) => Log2(x); @@ -3513,7 +3513,7 @@ public Vector Invoke(Vector x) => } /// 1f / (1f + MathF.Exp(-x)) - private readonly struct SigmoidOperator : IUnaryOperator + private readonly struct SigmoidOperator_Single : IUnaryOperator { public bool CanVectorize => false; public float Invoke(float x) => 1.0f / (1.0f + MathF.Exp(-x)); diff --git a/src/libraries/System.Numerics.Tensors/tests/Helpers.cs b/src/libraries/System.Numerics.Tensors/tests/Helpers.cs new file mode 100644 index 00000000000000..d6b5eef63d9dae --- /dev/null +++ b/src/libraries/System.Numerics.Tensors/tests/Helpers.cs @@ -0,0 +1,15 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Linq; + +namespace System.Numerics.Tensors.Tests +{ + internal static class Helpers + { + public static IEnumerable TensorLengthsIncluding0 => Enumerable.Range(0, 257); + + public static IEnumerable TensorLengths => Enumerable.Range(1, 256); + } +} diff --git a/src/libraries/System.Numerics.Tensors/tests/System.Numerics.Tensors.Tests.csproj b/src/libraries/System.Numerics.Tensors/tests/System.Numerics.Tensors.Tests.csproj index be4a103d7256ce..11149747b7a1d8 100644 --- a/src/libraries/System.Numerics.Tensors/tests/System.Numerics.Tensors.Tests.csproj +++ b/src/libraries/System.Numerics.Tensors/tests/System.Numerics.Tensors.Tests.csproj @@ -6,11 +6,14 @@ + + - + + @@ -18,4 +21,8 @@ + + + + \ No newline at end of file diff --git a/src/libraries/System.Numerics.Tensors/tests/TensorPrimitives.ConvertTo.cs b/src/libraries/System.Numerics.Tensors/tests/TensorPrimitives.ConvertTo.cs new file mode 100644 index 00000000000000..f8f6c07e2174ee --- /dev/null +++ b/src/libraries/System.Numerics.Tensors/tests/TensorPrimitives.ConvertTo.cs @@ -0,0 +1,166 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. 
+ +using System.Buffers; +using Xunit; + +namespace System.Numerics.Tensors.Tests +{ + public class ConvertToHalfTests + { + private readonly Random _random = new Random(42); + + [Fact] + public void ConvertToHalf() + { + Assert.All(Helpers.TensorLengthsIncluding0, tensorLength => + { + using BoundedMemory source = CreateAndFillSingleTensor(tensorLength); + foreach (int destLength in new[] { source.Length, source.Length + 1 }) + { + using BoundedMemory destination = BoundedMemory.Allocate(destLength); + destination.Span.Fill(Half.Zero); + + TensorPrimitives.ConvertToHalf(source, destination); + + for (int i = 0; i < source.Length; i++) + { + Assert.Equal((Half)source[i], destination[i]); + } + + if (destination.Length > source.Length) + { + for (int i = source.Length; i < destination.Length; i++) + { + Assert.Equal(Half.Zero, destination[i]); + } + } + } + }); + } + + [Fact] + public void ConvertToHalf_SpecialValues() + { + Assert.All(Helpers.TensorLengths, tensorLength => + { + using BoundedMemory source = CreateAndFillSingleTensor(tensorLength); + using BoundedMemory destination = BoundedMemory.Allocate(tensorLength); + + // NaN, infinities, and 0s + source[_random.Next(source.Length)] = float.NaN; + source[_random.Next(source.Length)] = float.PositiveInfinity; + source[_random.Next(source.Length)] = float.NegativeInfinity; + source[_random.Next(source.Length)] = 0; + source[_random.Next(source.Length)] = float.NegativeZero; + + TensorPrimitives.ConvertToHalf(source, destination); + + for (int i = 0; i < source.Length; i++) + { + Assert.Equal((Half)source[i], destination[i]); + } + }); + } + + [Fact] + public void ConvertToHalf_ThrowsForTooShortDestination() + { + Assert.All(Helpers.TensorLengths, tensorLength => + { + using BoundedMemory source = CreateAndFillSingleTensor(tensorLength); + Half[] destination = new Half[source.Length - 1]; + + AssertExtensions.Throws("destination", () => TensorPrimitives.ConvertToHalf(source, destination)); + }); + } + + [Fact] + public void ConvertToSingle() + { + Assert.All(Helpers.TensorLengthsIncluding0, tensorLength => + { + using BoundedMemory source = BoundedMemory.Allocate(tensorLength); + for (int i = 0; i < source.Length; i++) + { + source[i] = (Half)_random.NextSingle(); + } + + foreach (int destLength in new[] { source.Length, source.Length + 1 }) + { + using BoundedMemory destination = CreateSingleTensor(destLength); + destination.Span.Fill(0f); + + TensorPrimitives.ConvertToSingle(source, destination); + + for (int i = 0; i < source.Length; i++) + { + Assert.Equal((float)source[i], destination[i]); + } + + if (destination.Length > source.Length) + { + for (int i = source.Length; i < destination.Length; i++) + { + Assert.Equal(0f, destination[i]); + } + } + } + }); + } + + [Fact] + public void ConvertToSingle_SpecialValues() + { + Assert.All(Helpers.TensorLengths, tensorLength => + { + using BoundedMemory source = BoundedMemory.Allocate(tensorLength); + for (int i = 0; i < source.Length; i++) + { + source[i] = (Half)_random.NextSingle(); + } + + using BoundedMemory destination = CreateSingleTensor(tensorLength); + + // NaN, infinities, and 0s + source[_random.Next(source.Length)] = Half.NaN; + source[_random.Next(source.Length)] = Half.PositiveInfinity; + source[_random.Next(source.Length)] = Half.NegativeInfinity; + source[_random.Next(source.Length)] = Half.Zero; + source[_random.Next(source.Length)] = Half.NegativeZero; + + TensorPrimitives.ConvertToSingle(source, destination); + + for (int i = 0; i < source.Length; i++) + { + 
Assert.Equal((float)source[i], destination[i]); + } + }); + } + + [Fact] + public void ConvertToSingle_ThrowsForTooShortDestination() + { + Assert.All(Helpers.TensorLengths, tensorLength => + { + Half[] source = new Half[tensorLength]; + using BoundedMemory destination = CreateSingleTensor(source.Length - 1); + + AssertExtensions.Throws("destination", () => TensorPrimitives.ConvertToSingle(source, destination)); + }); + } + + public BoundedMemory CreateSingleTensor(int size) => BoundedMemory.Allocate(size); + + public BoundedMemory CreateAndFillSingleTensor(int size) + { + BoundedMemory tensor = CreateSingleTensor(size); + Span span = tensor; + for (int i = 0; i < span.Length; i++) + { + span[i] = (float)((_random.NextDouble() * 2) - 1); // For testing purposes, get a mix of negative and positive values. + } + return tensor; + } + } +} diff --git a/src/libraries/System.Numerics.Tensors/tests/TensorPrimitives.Generic.cs b/src/libraries/System.Numerics.Tensors/tests/TensorPrimitives.Generic.cs new file mode 100644 index 00000000000000..de0269ca889246 --- /dev/null +++ b/src/libraries/System.Numerics.Tensors/tests/TensorPrimitives.Generic.cs @@ -0,0 +1,295 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Buffers; +using System.Collections.Generic; +using System.Linq; +using System.Runtime.InteropServices; +using Xunit; +using Xunit.Sdk; + +namespace System.Numerics.Tensors.Tests +{ + public class DoubleGenericTensorPrimitives : GenericFloatingPointNumberTensorPrimitivesTests { } + public class SingleGenericTensorPrimitives : GenericFloatingPointNumberTensorPrimitivesTests { } + public class HalfGenericTensorPrimitives : GenericFloatingPointNumberTensorPrimitivesTests + { + protected override void AssertEqualTolerance(Half expected, Half actual) => AssertEqualTolerance(expected, actual, Half.CreateTruncating(0.001)); + } + public class NFloatGenericTensorPrimitives : GenericFloatingPointNumberTensorPrimitivesTests { } + + public class SByteGenericTensorPrimitives : GenericSignedIntegerTensorPrimitivesTests { } + public class Int16GenericTensorPrimitives : GenericSignedIntegerTensorPrimitivesTests { } + public class Int32GenericTensorPrimitives : GenericSignedIntegerTensorPrimitivesTests { } + public class Int64GenericTensorPrimitives : GenericSignedIntegerTensorPrimitivesTests { } + public class IntPtrGenericTensorPrimitives : GenericSignedIntegerTensorPrimitivesTests { } + public class Int128GenericTensorPrimitives : GenericSignedIntegerTensorPrimitivesTests { } + + public class ByteGenericTensorPrimitives : GenericIntegerTensorPrimitivesTests { } + public class UInt16GenericTensorPrimitives : GenericIntegerTensorPrimitivesTests { } + public class CharGenericTensorPrimitives : GenericIntegerTensorPrimitivesTests { } + public class UInt32GenericTensorPrimitives : GenericIntegerTensorPrimitivesTests { } + public class UInt64GenericTensorPrimitives : GenericIntegerTensorPrimitivesTests { } + public class UIntPtrGenericTensorPrimitives : GenericIntegerTensorPrimitivesTests { } + public class UInt128GenericTensorPrimitives : GenericIntegerTensorPrimitivesTests { } + + public unsafe abstract class GenericFloatingPointNumberTensorPrimitivesTests : GenericNumberTensorPrimitivesTests + where T : unmanaged, IFloatingPointIeee754, IMinMaxValue + { + protected override T Cosh(T x) => T.Cosh(x); + protected override void Cosh(ReadOnlySpan x, Span destination) => TensorPrimitives.Cosh(x, 
destination); + protected override T CosineSimilarity(ReadOnlySpan x, ReadOnlySpan y) => TensorPrimitives.CosineSimilarity(x, y); + protected override T Distance(ReadOnlySpan x, ReadOnlySpan y) => TensorPrimitives.Distance(x, y); + protected override void Exp(ReadOnlySpan x, Span destination) => TensorPrimitives.Exp(x, destination); + protected override T Exp(T x) => T.Exp(x); + protected override T Log(T x) => T.Log(x); + protected override void Log(ReadOnlySpan x, Span destination) => TensorPrimitives.Log(x, destination); + protected override T Log2(T x) => T.Log2(x); + protected override void Log2(ReadOnlySpan x, Span destination) => TensorPrimitives.Log2(x, destination); + protected override T Norm(ReadOnlySpan x) => TensorPrimitives.Norm(x); + protected override void Sigmoid(ReadOnlySpan x, Span destination) => TensorPrimitives.Sigmoid(x, destination); + protected override void Sinh(ReadOnlySpan x, Span destination) => TensorPrimitives.Sinh(x, destination); + protected override T Sinh(T x) => T.Sinh(x); + protected override void SoftMax(ReadOnlySpan x, Span destination) => TensorPrimitives.SoftMax(x, destination); + protected override T Sqrt(T x) => T.Sqrt(x); + protected override void Tanh(ReadOnlySpan x, Span destination) => TensorPrimitives.Tanh(x, destination); + protected override T Tanh(T x) => T.Tanh(x); + + protected override T NaN => T.NaN; + + protected override T NextRandom() => T.CreateTruncating((Random.NextDouble() * 2) - 1); // For testing purposes, get a mix of negative and positive values. + + protected override IEnumerable GetSpecialValues() + { + // NaN + yield return T.CreateTruncating(BitConverter.UInt32BitsToSingle(0xFFC0_0000)); // -qNaN / float.NaN + yield return T.CreateTruncating(BitConverter.UInt32BitsToSingle(0xFFFF_FFFF)); // -qNaN / all-bits-set + yield return T.CreateTruncating(BitConverter.UInt32BitsToSingle(0x7FC0_0000)); // +qNaN + yield return T.CreateTruncating(BitConverter.UInt32BitsToSingle(0xFFA0_0000)); // -sNaN + yield return T.CreateTruncating(BitConverter.UInt32BitsToSingle(0x7FA0_0000)); // +sNaN + + // +Infinity, -Infinity + yield return T.CreateTruncating(float.PositiveInfinity); + yield return T.CreateTruncating(float.NegativeInfinity); + + // +Zero, -Zero + yield return T.Zero; + yield return T.NegativeZero; + + // Subnormals + yield return T.Epsilon; + yield return -T.Epsilon; + yield return T.CreateTruncating(BitConverter.UInt32BitsToSingle(0x007F_FFFF)); + yield return T.CreateTruncating(BitConverter.UInt32BitsToSingle(0x807F_FFFF)); + + // Normals + yield return T.CreateTruncating(BitConverter.UInt32BitsToSingle(0x0080_0000)); + yield return T.CreateTruncating(BitConverter.UInt32BitsToSingle(0x8080_0000)); + yield return T.CreateTruncating(BitConverter.UInt32BitsToSingle(0x7F7F_FFFF)); // MaxValue + yield return T.CreateTruncating(BitConverter.UInt32BitsToSingle(0xFF7F_FFFF)); // MinValue + } + + protected override void SetSpecialValues(Span x, Span y) + { + int pos; + + // NaNs + pos = Random.Next(x.Length); + x[pos] = T.NaN; + y[pos] = T.CreateTruncating(BitConverter.UInt32BitsToSingle(0x7FC0_0000)); + + // +Infinity, -Infinity + pos = Random.Next(x.Length); + x[pos] = T.PositiveInfinity; + y[pos] = T.NegativeInfinity; + + // +Zero, -Zero + pos = Random.Next(x.Length); + x[pos] = T.Zero; + y[pos] = T.NegativeZero; + + // +Epsilon, -Epsilon + pos = Random.Next(x.Length); + x[pos] = T.Epsilon; + y[pos] = -T.Epsilon; + + // Same magnitude, opposite sign + pos = Random.Next(x.Length); + x[pos] = T.CreateTruncating(5); + y[pos] = 
T.CreateTruncating(-5); + } + } + + public unsafe abstract class GenericSignedIntegerTensorPrimitivesTests : GenericIntegerTensorPrimitivesTests + where T : unmanaged, IBinaryInteger, IMinMaxValue + { + [Fact] + public void Abs_MinValue_Throws() + { + Assert.All(Helpers.TensorLengths, tensorLength => + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory destination = CreateTensor(tensorLength); + + FillTensor(x.Span, T.MinValue); + x[^1] = T.MinValue; + + Assert.Throws(() => Abs(x, destination)); + }); + } + + [Fact] + public void SumOfMagnitudes_MinValue_Throws() + { + Assert.All(Helpers.TensorLengths, tensorLength => + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + + FillTensor(x.Span, T.MinValue); + x[^1] = T.MinValue; + + Assert.Throws(() => SumOfMagnitudes(x)); + }); + } + } + + public unsafe abstract class GenericIntegerTensorPrimitivesTests : GenericNumberTensorPrimitivesTests + where T : unmanaged, IBinaryInteger, IMinMaxValue + { + [Fact] + public void Divide_TwoTensors_ByZero_Throws() + { + Assert.All(Helpers.TensorLengths, tensorLength => + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory y = CreateTensor(tensorLength); + using BoundedMemory destination = CreateTensor(tensorLength); + + FillTensor(y.Span, T.Zero); + y[^1] = T.Zero; + + Exception e = Record.Exception(() => Divide(x, y, destination)); + Assert.True(e is DivideByZeroException or ArgumentOutOfRangeException); // TODO https://github.com/dotnet/runtime/issues/94593: Fix exception type + }); + } + + [Fact] + public void Divide_TensorScalar_ByZero_Throws() + { + Assert.All(Helpers.TensorLengths, tensorLength => + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory destination = CreateTensor(tensorLength); + + Exception e = Record.Exception(() => Divide(x, T.Zero, destination)); + Assert.True(e is DivideByZeroException or ArgumentOutOfRangeException); // TODO https://github.com/dotnet/runtime/issues/94593: Fix exception type + }); + } + } + + public unsafe abstract class GenericNumberTensorPrimitivesTests : TensorPrimitivesTests + where T : unmanaged, INumber, IMinMaxValue + { + protected override void Abs(ReadOnlySpan x, Span destination) => TensorPrimitives.Abs(x, destination); + protected override T Abs(T x) => T.Abs(x); + protected override void Add(ReadOnlySpan x, ReadOnlySpan y, Span destination) => TensorPrimitives.Add(x, y, destination); + protected override void Add(ReadOnlySpan x, T y, Span destination) => TensorPrimitives.Add(x, y, destination); + protected override T Add(T x, T y) => x + y; + protected override void AddMultiply(ReadOnlySpan x, ReadOnlySpan y, ReadOnlySpan z, Span destination) => TensorPrimitives.AddMultiply(x, y, z, destination); + protected override void AddMultiply(ReadOnlySpan x, ReadOnlySpan y, T z, Span destination) => TensorPrimitives.AddMultiply(x, y, z, destination); + protected override void AddMultiply(ReadOnlySpan x, T y, ReadOnlySpan z, Span destination) => TensorPrimitives.AddMultiply(x, y, z, destination); + protected override T AddMultiply(T x, T y, T z) => (x + y) * z; + protected override void Divide(ReadOnlySpan x, ReadOnlySpan y, Span destination) => TensorPrimitives.Divide(x, y, destination); + protected override void Divide(ReadOnlySpan x, T y, Span destination) => TensorPrimitives.Divide(x, y, destination); + protected override T Divide(T x, T y) => x / y; + protected override T Dot(ReadOnlySpan x, ReadOnlySpan y) => TensorPrimitives.Dot(x, y); +
protected override T Max(ReadOnlySpan x) => TensorPrimitives.Max(x); + protected override void Max(ReadOnlySpan x, ReadOnlySpan y, Span destination) => TensorPrimitives.Max(x, y, destination); + protected override T Max(T x, T y) => T.Max(x, y); + protected override T MaxMagnitude(ReadOnlySpan x) => TensorPrimitives.MaxMagnitude(x); + protected override void MaxMagnitude(ReadOnlySpan x, ReadOnlySpan y, Span destination) => TensorPrimitives.MaxMagnitude(x, y, destination); + protected override T MaxMagnitude(T x, T y) => T.MaxMagnitude(x, y); + protected override T Min(ReadOnlySpan x) => TensorPrimitives.Min(x); + protected override void Min(ReadOnlySpan x, ReadOnlySpan y, Span destination) => TensorPrimitives.Min(x, y, destination); + protected override T Min(T x, T y) => T.Min(x, y); + protected override T MinMagnitude(ReadOnlySpan x) => TensorPrimitives.MinMagnitude(x); + protected override void MinMagnitude(ReadOnlySpan x, ReadOnlySpan y, Span destination) => TensorPrimitives.MinMagnitude(x, y, destination); + protected override T MinMagnitude(T x, T y) => T.MinMagnitude(x, y); + protected override void Multiply(ReadOnlySpan x, ReadOnlySpan y, Span destination) => TensorPrimitives.Multiply(x, y, destination); + protected override void Multiply(ReadOnlySpan x, T y, Span destination) => TensorPrimitives.Multiply(x, y, destination); + protected override T Multiply(T x, T y) => x * y; + protected override void MultiplyAdd(ReadOnlySpan x, ReadOnlySpan y, ReadOnlySpan z, Span destination) => TensorPrimitives.MultiplyAdd(x, y, z, destination); + protected override void MultiplyAdd(ReadOnlySpan x, ReadOnlySpan y, T z, Span destination) => TensorPrimitives.MultiplyAdd(x, y, z, destination); + protected override void MultiplyAdd(ReadOnlySpan x, T y, ReadOnlySpan z, Span destination) => TensorPrimitives.MultiplyAdd(x, y, z, destination); + protected override void Negate(ReadOnlySpan x, Span destination) => TensorPrimitives.Negate(x, destination); + protected override T Product(ReadOnlySpan x) => TensorPrimitives.Product(x); + protected override T ProductOfSums(ReadOnlySpan x, ReadOnlySpan y) => TensorPrimitives.ProductOfSums(x, y); + protected override T ProductOfDifferences(ReadOnlySpan x, ReadOnlySpan y) => TensorPrimitives.ProductOfDifferences(x, y); + protected override void Subtract(ReadOnlySpan x, ReadOnlySpan y, Span destination) => TensorPrimitives.Subtract(x, y, destination); + protected override void Subtract(ReadOnlySpan x, T y, Span destination) => TensorPrimitives.Subtract(x, y, destination); + protected override T Subtract(T x, T y) => x - y; + protected override T Sum(ReadOnlySpan x) => TensorPrimitives.Sum(x); + protected override T SumOfMagnitudes(ReadOnlySpan x) => TensorPrimitives.SumOfMagnitudes(x); + protected override T SumOfSquares(ReadOnlySpan x) => TensorPrimitives.SumOfSquares(x); + + protected override T ConvertFromSingle(float f) => T.CreateTruncating(f); + protected override bool IsFloatingPoint => typeof(T) == typeof(Half) || base.IsFloatingPoint; + + protected override T NextRandom() + { + T value = default; + Random.NextBytes(MemoryMarshal.AsBytes(new Span(ref value))); + return value; + } + + protected override T NegativeZero => -T.Zero; + protected override T Zero => T.Zero; + protected override T One => T.One; + protected override T NegativeOne => -T.One; + protected override T MinValue => T.MinValue; + + protected override IEnumerable<(int Length, T Element)> VectorLengthAndIteratedRange(T min, T max, T increment) + { + foreach (int length in new[] { 4, 8, 16 }) 
+ { + for (T f = min; f <= max; f += increment) + { + yield return (length, f); + } + } + } + + protected override void AssertEqualTolerance(T expected, T actual) => AssertEqualTolerance(expected, actual, T.CreateTruncating(0.0001)); + + protected override void AssertEqualTolerance(T expected, T actual, T tolerance) + { + T diff = T.Abs(expected - actual); + if (diff > tolerance && diff > T.Max(T.Abs(expected), T.Abs(actual)) * tolerance) + { + throw EqualException.ForMismatchedValues(expected, actual); + } + } + + protected override T Cosh(T x) => throw new NotSupportedException(); + protected override void Cosh(ReadOnlySpan x, Span destination) => throw new NotSupportedException(); + protected override T CosineSimilarity(ReadOnlySpan x, ReadOnlySpan y) => throw new NotSupportedException(); + protected override T Distance(ReadOnlySpan x, ReadOnlySpan y) => throw new NotSupportedException(); + protected override void Exp(ReadOnlySpan x, Span destination) => throw new NotSupportedException(); + protected override T Exp(T x) => throw new NotSupportedException(); + protected override T Log(T x) => throw new NotSupportedException(); + protected override void Log(ReadOnlySpan x, Span destination) => throw new NotSupportedException(); + protected override T Log2(T x) => throw new NotSupportedException(); + protected override void Log2(ReadOnlySpan x, Span destination) => throw new NotSupportedException(); + protected override T Norm(ReadOnlySpan x) => throw new NotSupportedException(); + protected override void Sigmoid(ReadOnlySpan x, Span destination) => throw new NotSupportedException(); + protected override void Sinh(ReadOnlySpan x, Span destination) => throw new NotSupportedException(); + protected override T Sinh(T x) => throw new NotSupportedException(); + protected override void SoftMax(ReadOnlySpan x, Span destination) => throw new NotSupportedException(); + protected override T Sqrt(T x) => throw new NotSupportedException(); + protected override void Tanh(ReadOnlySpan x, Span destination) => throw new NotSupportedException(); + protected override T Tanh(T x) => throw new NotSupportedException(); + protected override T NaN => throw new NotSupportedException(); + protected override IEnumerable GetSpecialValues() => Enumerable.Empty(); + protected override void SetSpecialValues(Span x, Span y) { } + } +} diff --git a/src/libraries/System.Numerics.Tensors/tests/TensorPrimitives.NonGeneric.Single.cs b/src/libraries/System.Numerics.Tensors/tests/TensorPrimitives.NonGeneric.Single.cs new file mode 100644 index 00000000000000..50e3b3dae77d42 --- /dev/null +++ b/src/libraries/System.Numerics.Tensors/tests/TensorPrimitives.NonGeneric.Single.cs @@ -0,0 +1,376 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license.
diff --git a/src/libraries/System.Numerics.Tensors/tests/TensorPrimitives.NonGeneric.Single.cs b/src/libraries/System.Numerics.Tensors/tests/TensorPrimitives.NonGeneric.Single.cs
new file mode 100644
index 00000000000000..50e3b3dae77d42
--- /dev/null
+++ b/src/libraries/System.Numerics.Tensors/tests/TensorPrimitives.NonGeneric.Single.cs
@@ -0,0 +1,376 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System.Buffers;
+using System.Collections.Generic;
+using System.Linq;
+using System.Runtime.InteropServices;
+using Xunit;
+using Xunit.Sdk;
+
+namespace System.Numerics.Tensors.Tests
+{
+    public unsafe class NonGenericSingleTensorPrimitivesTests : TensorPrimitivesTests<float>
+    {
+        protected override void Abs(ReadOnlySpan<float> x, Span<float> destination) => TensorPrimitives.Abs(x, destination);
+        protected override float Abs(float x) => MathF.Abs(x);
+        protected override void Add(ReadOnlySpan<float> x, ReadOnlySpan<float> y, Span<float> destination) => TensorPrimitives.Add(x, y, destination);
+        protected override void Add(ReadOnlySpan<float> x, float y, Span<float> destination) => TensorPrimitives.Add(x, y, destination);
+        protected override float Add(float x, float y) => x + y;
+        protected override void AddMultiply(ReadOnlySpan<float> x, ReadOnlySpan<float> y, ReadOnlySpan<float> z, Span<float> destination) => TensorPrimitives.AddMultiply(x, y, z, destination);
+        protected override void AddMultiply(ReadOnlySpan<float> x, ReadOnlySpan<float> y, float z, Span<float> destination) => TensorPrimitives.AddMultiply(x, y, z, destination);
+        protected override void AddMultiply(ReadOnlySpan<float> x, float y, ReadOnlySpan<float> z, Span<float> destination) => TensorPrimitives.AddMultiply(x, y, z, destination);
+        protected override float AddMultiply(float x, float y, float z) => (x + y) * z;
+        protected override void Cosh(ReadOnlySpan<float> x, Span<float> destination) => TensorPrimitives.Cosh(x, destination);
+        protected override float Cosh(float x) => MathF.Cosh(x);
+        protected override float CosineSimilarity(ReadOnlySpan<float> x, ReadOnlySpan<float> y) => TensorPrimitives.CosineSimilarity(x, y);
+        protected override float Distance(ReadOnlySpan<float> x, ReadOnlySpan<float> y) => TensorPrimitives.Distance(x, y);
+        protected override void Divide(ReadOnlySpan<float> x, ReadOnlySpan<float> y, Span<float> destination) => TensorPrimitives.Divide(x, y, destination);
+        protected override void Divide(ReadOnlySpan<float> x, float y, Span<float> destination) => TensorPrimitives.Divide(x, y, destination);
+        protected override float Divide(float x, float y) => x / y;
+        protected override float Dot(ReadOnlySpan<float> x, ReadOnlySpan<float> y) => TensorPrimitives.Dot(x, y);
+        protected override void Exp(ReadOnlySpan<float> x, Span<float> destination) => TensorPrimitives.Exp(x, destination);
+        protected override float Exp(float x) => MathF.Exp(x);
+        protected override float Log(float x) => MathF.Log(x);
+        protected override void Log(ReadOnlySpan<float> x, Span<float> destination) => TensorPrimitives.Log(x, destination);
+        protected override float Log2(float x) => MathF.Log(x, 2);
+        protected override void Log2(ReadOnlySpan<float> x, Span<float> destination) => TensorPrimitives.Log2(x, destination);
+        protected override float Max(ReadOnlySpan<float> x) => TensorPrimitives.Max(x);
+        protected override void Max(ReadOnlySpan<float> x, ReadOnlySpan<float> y, Span<float> destination) => TensorPrimitives.Max(x, y, destination);
+        protected override float Max(float x, float y) => MathF.Max(x, y);
+        protected override float MaxMagnitude(ReadOnlySpan<float> x) => TensorPrimitives.MaxMagnitude(x);
+        protected override void MaxMagnitude(ReadOnlySpan<float> x, ReadOnlySpan<float> y, Span<float> destination) => TensorPrimitives.MaxMagnitude(x, y, destination);
+        protected override float MaxMagnitude(float x, float y)
+        {
+            float ax = MathF.Abs(x), ay = MathF.Abs(y);
+            return (ax > ay) || float.IsNaN(ax) || (ax == ay && *(int*)&x >= 0) ? x : y;
+        }
+        protected override float Min(ReadOnlySpan<float> x) => TensorPrimitives.Min(x);
+        protected override void Min(ReadOnlySpan<float> x, ReadOnlySpan<float> y, Span<float> destination) => TensorPrimitives.Min(x, y, destination);
+        protected override float Min(float x, float y) => MathF.Min(x, y);
+        protected override float MinMagnitude(ReadOnlySpan<float> x) => TensorPrimitives.MinMagnitude(x);
+        protected override void MinMagnitude(ReadOnlySpan<float> x, ReadOnlySpan<float> y, Span<float> destination) => TensorPrimitives.MinMagnitude(x, y, destination);
+        protected override float MinMagnitude(float x, float y)
+        {
+            float ax = MathF.Abs(x), ay = MathF.Abs(y);
+            return (ax < ay) || float.IsNaN(ax) || (ax == ay && *(int*)&x < 0) ? x : y;
+        }
+        protected override void Multiply(ReadOnlySpan<float> x, ReadOnlySpan<float> y, Span<float> destination) => TensorPrimitives.Multiply(x, y, destination);
+        protected override void Multiply(ReadOnlySpan<float> x, float y, Span<float> destination) => TensorPrimitives.Multiply(x, y, destination);
+        protected override float Multiply(float x, float y) => x * y;
+        protected override void MultiplyAdd(ReadOnlySpan<float> x, ReadOnlySpan<float> y, ReadOnlySpan<float> z, Span<float> destination) => TensorPrimitives.MultiplyAdd(x, y, z, destination);
+        protected override void MultiplyAdd(ReadOnlySpan<float> x, ReadOnlySpan<float> y, float z, Span<float> destination) => TensorPrimitives.MultiplyAdd(x, y, z, destination);
+        protected override void MultiplyAdd(ReadOnlySpan<float> x, float y, ReadOnlySpan<float> z, Span<float> destination) => TensorPrimitives.MultiplyAdd(x, y, z, destination);
+        protected override void Negate(ReadOnlySpan<float> x, Span<float> destination) => TensorPrimitives.Negate(x, destination);
+        protected override float Norm(ReadOnlySpan<float> x) => TensorPrimitives.Norm(x);
+        protected override float Product(ReadOnlySpan<float> x) => TensorPrimitives.Product(x);
+        protected override float ProductOfSums(ReadOnlySpan<float> x, ReadOnlySpan<float> y) => TensorPrimitives.ProductOfSums(x, y);
+        protected override float ProductOfDifferences(ReadOnlySpan<float> x, ReadOnlySpan<float> y) => TensorPrimitives.ProductOfDifferences(x, y);
+        protected override void Sigmoid(ReadOnlySpan<float> x, Span<float> destination) => TensorPrimitives.Sigmoid(x, destination);
+        protected override void Sinh(ReadOnlySpan<float> x, Span<float> destination) => TensorPrimitives.Sinh(x, destination);
+        protected override float Sinh(float x) => MathF.Sinh(x);
+        protected override void SoftMax(ReadOnlySpan<float> x, Span<float> destination) => TensorPrimitives.SoftMax(x, destination);
+        protected override float Sqrt(float x) => MathF.Sqrt(x);
+        protected override void Subtract(ReadOnlySpan<float> x, ReadOnlySpan<float> y, Span<float> destination) => TensorPrimitives.Subtract(x, y, destination);
+        protected override void Subtract(ReadOnlySpan<float> x, float y, Span<float> destination) => TensorPrimitives.Subtract(x, y, destination);
+        protected override float Subtract(float x, float y) => x - y;
+        protected override float Sum(ReadOnlySpan<float> x) => TensorPrimitives.Sum(x);
+        protected override float SumOfMagnitudes(ReadOnlySpan<float> x) => TensorPrimitives.SumOfMagnitudes(x);
+        protected override float SumOfSquares(ReadOnlySpan<float> x) => TensorPrimitives.SumOfSquares(x);
+        protected override void Tanh(ReadOnlySpan<float> x, Span<float> destination) => TensorPrimitives.Tanh(x, destination);
+        protected override float Tanh(float x) => MathF.Tanh(x);
+
+        protected override float ConvertFromSingle(float f) => f;
+
+        protected override float NaN => float.NaN;
+        protected override float NegativeZero => -0f;
+        protected override float Zero => 0f;
+        protected override float One => 1f;
+        protected override float NegativeOne => -1f;
+        protected override float MinValue => float.MinValue;
+
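The scalar MaxMagnitude/MinMagnitude reference implementations above mirror IEEE 754 maximumMagnitude/minimumMagnitude semantics: NaN propagates, and when the two magnitudes are exactly equal the sign bit breaks the tie, so MaxMagnitude(-0f, +0f) is +0f. The *(int*)&x >= 0 expression reinterprets the float's bits to read the sign bit, which is the only way to distinguish -0f from +0f here (an ordinary comparison cannot, since -0f >= 0f is true). A sketch of an equivalent check without unsafe code, using float.IsNegative (an alternative formulation, not the test's code):

    static float MaxMagnitudeReference(float x, float y)
    {
        float ax = MathF.Abs(x), ay = MathF.Abs(y);
        // NaN wins; on an exact magnitude tie, the operand whose sign bit is clear wins.
        return (ax > ay) || float.IsNaN(ax) || (ax == ay && !float.IsNegative(x)) ? x : y;
    }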
+        protected override IEnumerable<(int Length, float Element)> VectorLengthAndIteratedRange(float min, float max, float increment)
+        {
+            foreach (int length in new[] { 4, 8, 16 })
+            {
+                for (float f = min; f <= max; f += increment)
+                {
+                    yield return (length, f);
+                }
+            }
+        }
+
+        protected override float NextRandom() => (float)((Random.NextDouble() * 2) - 1); // For testing purposes, get a mix of negative and positive values.
+
+        protected override void AssertEqualTolerance(float expected, float actual) => AssertEqualTolerance(expected, actual, 0.0001f);
+
+        protected override void AssertEqualTolerance(float expected, float actual, float tolerance)
+        {
+            double diff = Math.Abs((double)expected - (double)actual);
+            if (diff > tolerance && diff > Math.Max(Math.Abs(expected), Math.Abs(actual)) * tolerance)
+            {
+                throw EqualException.ForMismatchedValues(expected, actual);
+            }
+        }
+
+        protected override IEnumerable<float> GetSpecialValues()
+        {
+            // NaN
+            yield return UInt32ToSingle(0xFFC0_0000); // -qNaN / float.NaN
+            yield return UInt32ToSingle(0xFFFF_FFFF); // -qNaN / all-bits-set
+            yield return UInt32ToSingle(0x7FC0_0000); // +qNaN
+            yield return UInt32ToSingle(0xFFA0_0000); // -sNaN
+            yield return UInt32ToSingle(0x7FA0_0000); // +sNaN
+
+            // +Infinity, -Infinity
+            yield return float.PositiveInfinity;
+            yield return float.NegativeInfinity;
+
+            // +Zero, -Zero
+            yield return +0.0f;
+            yield return -0.0f;
+
+            // Subnormals
+            yield return +float.Epsilon;
+            yield return -float.Epsilon;
+            yield return UInt32ToSingle(0x007F_FFFF);
+            yield return UInt32ToSingle(0x807F_FFFF);
+
+            // Normals
+            yield return UInt32ToSingle(0x0080_0000);
+            yield return UInt32ToSingle(0x8080_0000);
+            yield return UInt32ToSingle(0x7F7F_FFFF); // MaxValue
+            yield return UInt32ToSingle(0xFF7F_FFFF); // MinValue
+        }
+
+        protected override void SetSpecialValues(Span<float> x, Span<float> y)
+        {
+            int pos;
+
+            // NaNs
+            pos = Random.Next(x.Length);
+            x[pos] = float.NaN;
+            y[pos] = UInt32ToSingle(0x7FC0_0000);
+
+            // +Infinity, -Infinity
+            pos = Random.Next(x.Length);
+            x[pos] = float.PositiveInfinity;
+            y[pos] = float.NegativeInfinity;
+
+            // +Zero, -Zero
+            pos = Random.Next(x.Length);
+            x[pos] = +0.0f;
+            y[pos] = -0.0f;
+
+            // +Epsilon, -Epsilon
+            pos = Random.Next(x.Length);
+            x[pos] = +float.Epsilon;
+            y[pos] = -float.Epsilon;
+
+            // Same magnitude, opposite sign
+            pos = Random.Next(x.Length);
+            x[pos] = +5.0f;
+            y[pos] = -5.0f;
+        }
+
+        private static unsafe float UInt32ToSingle(uint i) => *(float*)&i;
+
+        // TODO: Move these IndexOf tests to the base class once generic versions are implemented.
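GetSpecialValues constructs its NaN and boundary inputs directly from IEEE 754 bit patterns, so the tests cover distinct NaN encodings (quiet vs. signaling, both signs) as well as subnormal and extreme normal values. On recent .NET the same bit reinterpretation can be written without unsafe code (an equivalent sketch, not what this file uses):

    float negQNaN = BitConverter.UInt32BitsToSingle(0xFFC0_0000); // same bits as float.NaN
    float posSNaN = BitConverter.UInt32BitsToSingle(0x7FA0_0000); // signaling NaN payload
    // Every NaN encoding compares unequal to everything, including itself:
    Console.WriteLine(negQNaN == negQNaN);   // False
    Console.WriteLine(float.IsNaN(posSNaN)); // True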
+        #region IndexOfMax
+        [Fact]
+        public void IndexOfMax_ReturnsNegative1OnEmpty()
+        {
+            Assert.Equal(-1, TensorPrimitives.IndexOfMax(ReadOnlySpan<float>.Empty));
+        }
+
+        [Fact]
+        public void IndexOfMax()
+        {
+            Assert.All(Helpers.TensorLengths, tensorLength =>
+            {
+                foreach (int expected in new[] { 0, tensorLength / 2, tensorLength - 1 })
+                {
+                    using BoundedMemory<float> x = CreateAndFillTensor(tensorLength);
+                    x[expected] = Enumerable.Max(MemoryMarshal.ToEnumerable<float>(x.Memory)) + 1;
+                    Assert.Equal(expected, TensorPrimitives.IndexOfMax(x));
+                }
+            });
+        }
+
+        [Fact]
+        public void IndexOfMax_FirstNaNReturned()
+        {
+            Assert.All(Helpers.TensorLengths, tensorLength =>
+            {
+                foreach (int expected in new[] { 0, tensorLength / 2, tensorLength - 1 })
+                {
+                    using BoundedMemory<float> x = CreateAndFillTensor(tensorLength);
+                    x[expected] = float.NaN;
+                    x[tensorLength - 1] = float.NaN;
+                    Assert.Equal(expected, TensorPrimitives.IndexOfMax(x));
+                }
+            });
+        }
+
+        [Fact]
+        public void IndexOfMax_Negative0LesserThanPositive0()
+        {
+            Assert.Equal(1, TensorPrimitives.IndexOfMax([-0f, +0f]));
+            Assert.Equal(0, TensorPrimitives.IndexOfMax([-0f, -0f, -0f, -0f]));
+            Assert.Equal(4, TensorPrimitives.IndexOfMax([-0f, -0f, -0f, -0f, +0f, +0f, +0f]));
+            Assert.Equal(0, TensorPrimitives.IndexOfMax([+0f, -0f]));
+            Assert.Equal(1, TensorPrimitives.IndexOfMax([-1, -0f]));
+            Assert.Equal(2, TensorPrimitives.IndexOfMax([-1, -0f, 1]));
+        }
+        #endregion
+
+        #region IndexOfMaxMagnitude
+        [Fact]
+        public void IndexOfMaxMagnitude_ReturnsNegative1OnEmpty()
+        {
+            Assert.Equal(-1, TensorPrimitives.IndexOfMaxMagnitude(ReadOnlySpan<float>.Empty));
+        }
+
+        [Fact]
+        public void IndexOfMaxMagnitude()
+        {
+            Assert.All(Helpers.TensorLengths, tensorLength =>
+            {
+                foreach (int expected in new[] { 0, tensorLength / 2, tensorLength - 1 })
+                {
+                    using BoundedMemory<float> x = CreateAndFillTensor(tensorLength);
+                    x[expected] = Enumerable.Max(MemoryMarshal.ToEnumerable<float>(x.Memory), Math.Abs) + 1;
+                    Assert.Equal(expected, TensorPrimitives.IndexOfMaxMagnitude(x));
+                }
+            });
+        }
+
+        [Fact]
+        public void IndexOfMaxMagnitude_FirstNaNReturned()
+        {
+            Assert.All(Helpers.TensorLengths, tensorLength =>
+            {
+                foreach (int expected in new[] { 0, tensorLength / 2, tensorLength - 1 })
+                {
+                    using BoundedMemory<float> x = CreateAndFillTensor(tensorLength);
+                    x[expected] = float.NaN;
+                    x[tensorLength - 1] = float.NaN;
+                    Assert.Equal(expected, TensorPrimitives.IndexOfMaxMagnitude(x));
+                }
+            });
+        }
+
+        [Fact]
+        public void IndexOfMaxMagnitude_Negative0LesserThanPositive0()
+        {
+            Assert.Equal(0, TensorPrimitives.IndexOfMaxMagnitude([-0f, -0f, -0f, -0f]));
+            Assert.Equal(1, TensorPrimitives.IndexOfMaxMagnitude([-0f, +0f]));
+            Assert.Equal(1, TensorPrimitives.IndexOfMaxMagnitude([-0f, +0f, +0f, +0f]));
+            Assert.Equal(0, TensorPrimitives.IndexOfMaxMagnitude([+0f, -0f]));
+            Assert.Equal(0, TensorPrimitives.IndexOfMaxMagnitude([-1, -0f]));
+            Assert.Equal(2, TensorPrimitives.IndexOfMaxMagnitude([-1, -0f, 1]));
+        }
+        #endregion
+
+        #region IndexOfMin
+        [Fact]
+        public void IndexOfMin_ReturnsNegative1OnEmpty()
+        {
+            Assert.Equal(-1, TensorPrimitives.IndexOfMin(ReadOnlySpan<float>.Empty));
+        }
+
+        [Fact]
+        public void IndexOfMin()
+        {
+            Assert.All(Helpers.TensorLengths, tensorLength =>
+            {
+                foreach (int expected in new[] { 0, tensorLength / 2, tensorLength - 1 })
+                {
+                    using BoundedMemory<float> x = CreateAndFillTensor(tensorLength);
+                    x[expected] = Enumerable.Min(MemoryMarshal.ToEnumerable<float>(x.Memory)) - 1;
+                    Assert.Equal(expected, TensorPrimitives.IndexOfMin(x));
+                }
+            });
+        }
+
+        [Fact]
+        public void IndexOfMin_FirstNaNReturned()
+        {
+            Assert.All(Helpers.TensorLengths, tensorLength =>
+            {
+                foreach (int expected in new[] { 0, tensorLength / 2, tensorLength - 1 })
+                {
+                    using BoundedMemory<float> x = CreateAndFillTensor(tensorLength);
+                    x[expected] = float.NaN;
+                    x[tensorLength - 1] = float.NaN;
+                    Assert.Equal(expected, TensorPrimitives.IndexOfMin(x));
+                }
+            });
+        }
+
+        [Fact]
+        public void IndexOfMin_Negative0LesserThanPositive0()
+        {
+            Assert.Equal(0, TensorPrimitives.IndexOfMin([-0f, +0f]));
+            Assert.Equal(1, TensorPrimitives.IndexOfMin([+0f, -0f]));
+            Assert.Equal(1, TensorPrimitives.IndexOfMin([+0f, -0f, -0f, -0f, -0f]));
+            Assert.Equal(0, TensorPrimitives.IndexOfMin([-1, -0f]));
+            Assert.Equal(0, TensorPrimitives.IndexOfMin([-1, -0f, 1]));
+        }
+        #endregion
+
+        #region IndexOfMinMagnitude
+        [Fact]
+        public void IndexOfMinMagnitude_ReturnsNegative1OnEmpty()
+        {
+            Assert.Equal(-1, TensorPrimitives.IndexOfMinMagnitude(ReadOnlySpan<float>.Empty));
+        }
+
+        [Fact]
+        public void IndexOfMinMagnitude()
+        {
+            Assert.All(Helpers.TensorLengths, tensorLength =>
+            {
+                foreach (int expected in new[] { 0, tensorLength / 2, tensorLength - 1 })
+                {
+                    using BoundedMemory<float> x = CreateTensor(tensorLength);
+                    for (int i = 0; i < x.Length; i++)
+                    {
+                        x[i] = i % 2 == 0 ? 42 : -42;
+                    }
+
+                    x[expected] = -41;
+
+                    Assert.Equal(expected, TensorPrimitives.IndexOfMinMagnitude(x));
+                }
+            });
+        }
+
+        [Fact]
+        public void IndexOfMinMagnitude_FirstNaNReturned()
+        {
+            Assert.All(Helpers.TensorLengths, tensorLength =>
+            {
+                foreach (int expected in new[] { 0, tensorLength / 2, tensorLength - 1 })
+                {
+                    using BoundedMemory<float> x = CreateAndFillTensor(tensorLength);
+                    x[expected] = float.NaN;
+                    x[tensorLength - 1] = float.NaN;
+                    Assert.Equal(expected, TensorPrimitives.IndexOfMinMagnitude(x));
+                }
+            });
+        }
+
+        [Fact]
+        public void IndexOfMinMagnitude_Negative0LesserThanPositive0()
+        {
+            Assert.Equal(0, TensorPrimitives.IndexOfMinMagnitude([-0f, -0f, -0f, -0f]));
+            Assert.Equal(0, TensorPrimitives.IndexOfMinMagnitude([-0f, +0f]));
+            Assert.Equal(1, TensorPrimitives.IndexOfMinMagnitude([+0f, -0f]));
+            Assert.Equal(1, TensorPrimitives.IndexOfMinMagnitude([+0f, -0f, -0f, -0f]));
+            Assert.Equal(1, TensorPrimitives.IndexOfMinMagnitude([-1, -0f]));
+            Assert.Equal(1, TensorPrimitives.IndexOfMinMagnitude([-1, -0f, 1]));
+        }
+        #endregion
+    }
+}
diff --git a/src/libraries/System.Numerics.Tensors/tests/TensorPrimitivesTests.cs b/src/libraries/System.Numerics.Tensors/tests/TensorPrimitivesTests.cs
index 847fa34a4470cd..9083d7d8bf8962 100644
--- a/src/libraries/System.Numerics.Tensors/tests/TensorPrimitivesTests.cs
+++ b/src/libraries/System.Numerics.Tensors/tests/TensorPrimitivesTests.cs
@@ -6,111 +6,141 @@
 using System.Linq;
 using System.Runtime.InteropServices;
 using Xunit;
-using Xunit.Sdk;
-
-#pragma warning disable xUnit1025 // reporting duplicate test cases due to not distinguishing 0.0 from -0.0
 
 namespace System.Numerics.Tensors.Tests
 {
-    public static partial class TensorPrimitivesTests
+    public abstract class TensorPrimitivesTests<T> where T : unmanaged, IEquatable<T>
     {
+        #region Abstract Methods Under Test
+        protected abstract void Abs(ReadOnlySpan<T> x, Span<T> destination);
+        protected abstract void Add(ReadOnlySpan<T> x, ReadOnlySpan<T> y, Span<T> destination);
+        protected abstract void Add(ReadOnlySpan<T> x, T y, Span<T> destination);
+        protected abstract void AddMultiply(ReadOnlySpan<T> x, ReadOnlySpan<T> y, ReadOnlySpan<T> z, Span<T> destination);
+        protected abstract void AddMultiply(ReadOnlySpan<T> x, ReadOnlySpan<T> y, T z, Span<T> destination);
+        protected abstract void AddMultiply(ReadOnlySpan<T> x, T y, ReadOnlySpan<T> z, Span<T> destination);
+        protected abstract void Cosh(ReadOnlySpan<T> x, Span<T> destination);
+        protected abstract T CosineSimilarity(ReadOnlySpan<T> x, ReadOnlySpan<T> y);
+        protected abstract T Distance(ReadOnlySpan<T> x, ReadOnlySpan<T> y);
+        protected abstract void Divide(ReadOnlySpan<T> x, ReadOnlySpan<T> y, Span<T> destination);
+        protected abstract void Divide(ReadOnlySpan<T> x, T y, Span<T> destination);
+        protected abstract T Dot(ReadOnlySpan<T> x, ReadOnlySpan<T> y);
+        protected abstract void Exp(ReadOnlySpan<T> x, Span<T> destination);
+        protected abstract void Log(ReadOnlySpan<T> x, Span<T> destination);
+        protected abstract void Log2(ReadOnlySpan<T> x, Span<T> destination);
+        protected abstract T Max(ReadOnlySpan<T> x);
+        protected abstract void Max(ReadOnlySpan<T> x, ReadOnlySpan<T> y, Span<T> destination);
+        protected abstract T Max(T x, T y);
+        protected abstract T MaxMagnitude(ReadOnlySpan<T> x);
+        protected abstract void MaxMagnitude(ReadOnlySpan<T> x, ReadOnlySpan<T> y, Span<T> destination);
+        protected abstract T MaxMagnitude(T x, T y);
+        protected abstract T Min(ReadOnlySpan<T> x);
+        protected abstract void Min(ReadOnlySpan<T> x, ReadOnlySpan<T> y, Span<T> destination);
+        protected abstract T Min(T x, T y);
+        protected abstract T MinMagnitude(ReadOnlySpan<T> x);
+        protected abstract void MinMagnitude(ReadOnlySpan<T> x, ReadOnlySpan<T> y, Span<T> destination);
+        protected abstract T MinMagnitude(T x, T y);
+        protected abstract void Multiply(ReadOnlySpan<T> x, ReadOnlySpan<T> y, Span<T> destination);
+        protected abstract void Multiply(ReadOnlySpan<T> x, T y, Span<T> destination);
+        protected abstract void MultiplyAdd(ReadOnlySpan<T> x, ReadOnlySpan<T> y, ReadOnlySpan<T> z, Span<T> destination);
+        protected abstract void MultiplyAdd(ReadOnlySpan<T> x, ReadOnlySpan<T> y, T z, Span<T> destination);
+        protected abstract void MultiplyAdd(ReadOnlySpan<T> x, T y, ReadOnlySpan<T> z, Span<T> destination);
+        protected abstract void Negate(ReadOnlySpan<T> x, Span<T> destination);
+        protected abstract T Norm(ReadOnlySpan<T> x);
+        protected abstract T Product(ReadOnlySpan<T> x);
+        protected abstract T ProductOfSums(ReadOnlySpan<T> x, ReadOnlySpan<T> y);
+        protected abstract T ProductOfDifferences(ReadOnlySpan<T> x, ReadOnlySpan<T> y);
+        protected abstract void Sigmoid(ReadOnlySpan<T> x, Span<T> destination);
+        protected abstract void Sinh(ReadOnlySpan<T> x, Span<T> destination);
+        protected abstract void SoftMax(ReadOnlySpan<T> x, Span<T> destination);
+        protected abstract void Subtract(ReadOnlySpan<T> x, ReadOnlySpan<T> y, Span<T> destination);
+        protected abstract void Subtract(ReadOnlySpan<T> x, T y, Span<T> destination);
+        protected abstract T Sum(ReadOnlySpan<T> x);
+        protected abstract T SumOfMagnitudes(ReadOnlySpan<T> x);
+        protected abstract T SumOfSquares(ReadOnlySpan<T> x);
+        protected abstract void Tanh(ReadOnlySpan<T> x, Span<T> destination);
+        #endregion
+
+        #region Abstract Validation
+        protected abstract T Abs(T x);
+        protected abstract T Add(T x, T y);
+        protected abstract T AddMultiply(T x, T y, T z);
+        protected abstract T Cosh(T x);
+        protected abstract T Divide(T x, T y);
+        protected abstract T Exp(T x);
+        protected abstract T Log(T x);
+        protected abstract T Log2(T x);
+        protected abstract T Multiply(T x, T y);
+        protected abstract T Sinh(T x);
+        protected abstract T Sqrt(T x);
+        protected abstract T Subtract(T x, T y);
+        protected abstract T Tanh(T x);
+
+        protected abstract T NaN { get; }
+        protected abstract T NegativeZero { get; }
+        protected abstract T Zero { get; }
+        protected abstract T One { get; }
+        protected abstract T NegativeOne { get; }
+        protected abstract T MinValue { get; }
+        #endregion
+
         #region Test Utilities
-        public static IEnumerable<object[]> TensorLengthsIncluding0 =>
-            TensorLengths.Concat(new object[][] { [0] });
+        protected virtual bool IsFloatingPoint => typeof(T) == typeof(float) || typeof(T) == typeof(double);
 
-        public static IEnumerable<object[]> TensorLengths =>
-            from length in Enumerable.Range(1, 256)
-            select new object[] { length };
+        protected abstract T ConvertFromSingle(float f);
 
-        public static IEnumerable<object[]> VectorLengthAndIteratedRange(float min, float max, float increment)
-        {
-            foreach (int length in new[] { 4, 8, 16 })
-            {
-                for (float f = min; f <= max; f += increment)
-                {
-                    yield return new object[] { length, f };
-                }
-            }
-        }
+        protected abstract IEnumerable<T> GetSpecialValues();
+
+        /// <summary>
+        /// Loads a variety of special values (e.g. NaN) into random positions in <paramref name="x"/>
+        /// and related values into the corresponding positions in <paramref name="y"/>.
+        /// </summary>
+        protected abstract void SetSpecialValues(Span<T> x, Span<T> y);
 
-        private static readonly Random s_random = new Random(20230828);
+        protected abstract T NextRandom();
 
-        private static BoundedMemory<float> CreateTensor(int size) => BoundedMemory.Allocate<float>(size);
+        protected abstract void AssertEqualTolerance(T expected, T actual);
 
-        private static BoundedMemory<float> CreateAndFillTensor(int size)
+        protected abstract void AssertEqualTolerance(T expected, T actual, T tolerance);
+
+        protected abstract IEnumerable<(int Length, T Element)> VectorLengthAndIteratedRange(T min, T max, T increment);
+
+        protected Random Random { get; } = new Random(42);
+
+        protected BoundedMemory<T> CreateTensor(int size) => BoundedMemory.Allocate<T>(size);
+
+        public BoundedMemory<T> CreateAndFillTensor(int size)
         {
-            BoundedMemory<float> tensor = CreateTensor(size);
-            FillTensor(tensor.Span);
+            BoundedMemory<T> tensor = CreateTensor(size);
+            FillTensor(tensor);
             return tensor;
         }
 
-        private static void FillTensor(Span<float> tensor)
+        protected void FillTensor(Span<T> span)
         {
-            for (int i = 0; i < tensor.Length; i++)
+            for (int i = 0; i < span.Length; i++)
             {
-                tensor[i] = NextSingle();
+                span[i] = NextRandom();
             }
         }
 
-        private static float NextSingle() =>
-            // For testing purposes, get a mix of negative and positive values.
-            (float)((s_random.NextDouble() * 2) - 1);
-
-        private static void AssertEqualTolerance(double expected, double actual, double tolerance = 0.00001f)
+        protected void FillTensor(Span<T> span, T avoid)
         {
-            double diff = Math.Abs(expected - actual);
-            if (diff > tolerance &&
-                diff > Math.Max(Math.Abs(expected), Math.Abs(actual)) * tolerance)
+            for (int i = 0; i < span.Length; i++)
             {
-                throw EqualException.ForMismatchedValues(expected, actual);
+                span[i] = NextRandom(avoid);
             }
         }
 
-        private static unsafe float MathFMaxMagnitude(float x, float y)
-        {
-            float ax = MathF.Abs(x), ay = MathF.Abs(y);
-            return (ax > ay) || float.IsNaN(ax) || (ax == ay && *(int*)&x >= 0) ? x : y;
-        }
-
-        private static unsafe float MathFMinMagnitude(float x, float y)
+        protected T NextRandom(T avoid)
         {
-            float ax = MathF.Abs(x), ay = MathF.Abs(y);
-            return (ax < ay) || float.IsNaN(ax) || (ax == ay && *(int*)&x < 0) ? x : y;
-        }
-
-        private static unsafe float UInt32ToSingle(uint i) => *(float*)&i;
-
-        private static unsafe uint SingleToUInt32(float f) => *(uint*)&f;
-
-        /// <summary>Gets a variety of special values (e.g. NaN).</summary>
-        private static IEnumerable<float> GetSpecialValues()
-        {
-            // NaN
-            yield return UInt32ToSingle(0xFFC0_0000); // -qNaN / float.NaN
-            yield return UInt32ToSingle(0xFFFF_FFFF); // -qNaN / all-bits-set
-            yield return UInt32ToSingle(0x7FC0_0000); // +qNaN
-            yield return UInt32ToSingle(0xFFA0_0000); // -sNaN
-            yield return UInt32ToSingle(0x7FA0_0000); // +sNaN
-
-            // +Infinity, -Infinity
-            yield return float.PositiveInfinity;
-            yield return float.NegativeInfinity;
-
-            // +Zero, -Zero
-            yield return +0.0f;
-            yield return -0.0f;
-
-            // Subnormals
-            yield return +float.Epsilon;
-            yield return -float.Epsilon;
-            yield return UInt32ToSingle(0x007F_FFFF);
-            yield return UInt32ToSingle(0x807F_FFFF);
-
-            // Normals
-            yield return UInt32ToSingle(0x0080_0000);
-            yield return UInt32ToSingle(0x8080_0000);
-            yield return UInt32ToSingle(0x7F7F_FFFF); // MaxValue
-            yield return UInt32ToSingle(0xFF7F_FFFF); // MinValue
+            while (true)
+            {
+                T value = NextRandom();
+                if (!value.Equals(avoid))
+                {
+                    return value;
+                }
+            }
         }
 
         /// <summary>
@@ -118,12 +148,12 @@ private static IEnumerable<float> GetSpecialValues()
        /// the value is stored into a random position in <paramref name="x"/>, and the original
        /// value is subsequently restored.
        /// </summary>
-        private static void RunForEachSpecialValue(Action action, BoundedMemory<float> x)
+        protected void RunForEachSpecialValue(Action action, BoundedMemory<T> x)
         {
-            foreach (float value in GetSpecialValues())
+            foreach (T value in GetSpecialValues())
             {
-                int pos = s_random.Next(x.Length);
-                float orig = x[pos];
+                int pos = Random.Next(x.Length);
+                T orig = x[pos];
                 x[pos] = value;
 
                 action();
 
@@ -131,2861 +161,2881 @@ private static void RunForEachSpecialValue(Action action, BoundedMemory<float> x
                 x[pos] = orig;
             }
         }
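RunForEachSpecialValue plants each special value at a random index, invokes the assertion callback, and then restores the original element, so a single randomly filled tensor can be reused across every special value. The typical call shape from a derived test looks like the following (a sketch mirroring the Cosh_SpecialValues test later in this diff, but for Exp; AssertEqualTolerance and the tensor helpers come from the test base class):

    using BoundedMemory<float> x = CreateAndFillTensor(tensorLength);
    using BoundedMemory<float> destination = CreateTensor(tensorLength);
    RunForEachSpecialValue(() =>
    {
        // Invoked once per special value; x already contains that value at a random index.
        TensorPrimitives.Exp(x, destination);
        for (int i = 0; i < x.Length; i++)
        {
            AssertEqualTolerance(MathF.Exp(x[i]), destination[i]);
        }
    }, x);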
- /// - private static void SetSpecialValues(Span x, Span y) - { - int pos; - - // NaNs - pos = s_random.Next(x.Length); - x[pos] = float.NaN; - y[pos] = UInt32ToSingle(0x7FC0_0000); - - // +Infinity, -Infinity - pos = s_random.Next(x.Length); - x[pos] = float.PositiveInfinity; - y[pos] = float.NegativeInfinity; - - // +Zero, -Zero - pos = s_random.Next(x.Length); - x[pos] = +0.0f; - y[pos] = -0.0f; - - // +Epsilon, -Epsilon - pos = s_random.Next(x.Length); - x[pos] = +float.Epsilon; - y[pos] = -float.Epsilon; - - // Same magnitude, opposite sign - pos = s_random.Next(x.Length); - x[pos] = +5.0f; - y[pos] = -5.0f; - } #endregion #region Abs - [Theory] - [MemberData(nameof(TensorLengthsIncluding0))] - public static void Abs(int tensorLength) + [Fact] + public void Abs_AllLengths() { - using BoundedMemory x = CreateAndFillTensor(tensorLength); - using BoundedMemory destination = CreateTensor(tensorLength); + Assert.All(Helpers.TensorLengthsIncluding0, tensorLength => + { + using BoundedMemory x = CreateTensor(tensorLength); + FillTensor(x, MinValue); + using BoundedMemory destination = CreateTensor(tensorLength); - TensorPrimitives.Abs(x, destination); + Abs(x, destination); - for (int i = 0; i < x.Length; i++) - { - AssertEqualTolerance(MathF.Abs(x[i]), destination[i]); - } + for (int i = 0; i < x.Length; i++) + { + AssertEqualTolerance(Abs(x[i]), destination[i]); + } + }); } - [Theory] - [MemberData(nameof(TensorLengthsIncluding0))] - public static void Abs_InPlace(int tensorLength) + [Fact] + public void Abs_InPlace() { - using BoundedMemory x = CreateAndFillTensor(tensorLength); - float[] xOrig = x.Span.ToArray(); + Assert.All(Helpers.TensorLengthsIncluding0, tensorLength => + { + using BoundedMemory x = CreateTensor(tensorLength); + FillTensor(x, MinValue); + T[] xOrig = x.Span.ToArray(); - TensorPrimitives.Abs(x, x); + Abs(x, x); - for (int i = 0; i < x.Length; i++) - { - AssertEqualTolerance(MathF.Abs(xOrig[i]), x[i]); - } + for (int i = 0; i < x.Length; i++) + { + AssertEqualTolerance(Abs(xOrig[i]), x[i]); + } + }); } - [Theory] - [MemberData(nameof(TensorLengths))] - public static void Abs_ThrowsForTooShortDestination(int tensorLength) + [Fact] + public void Abs_ThrowsForTooShortDestination() { - using BoundedMemory x = CreateAndFillTensor(tensorLength); - using BoundedMemory destination = CreateTensor(tensorLength - 1); + Assert.All(Helpers.TensorLengths, tensorLength => + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory destination = CreateTensor(tensorLength - 1); - AssertExtensions.Throws("destination", () => TensorPrimitives.Abs(x, destination)); + AssertExtensions.Throws("destination", () => Abs(x, destination)); + }); } [Fact] - public static void Abs_ThrowsForOverlapppingInputsWithOutputs() + public void Abs_ThrowsForOverlapppingInputsWithOutputs() { - float[] array = new float[10]; - AssertExtensions.Throws("destination", () => TensorPrimitives.Abs(array.AsSpan(1, 5), array.AsSpan(0, 5))); - AssertExtensions.Throws("destination", () => TensorPrimitives.Abs(array.AsSpan(1, 5), array.AsSpan(2, 5))); + T[] array = new T[10]; + AssertExtensions.Throws("destination", () => Abs(array.AsSpan(1, 5), array.AsSpan(0, 5))); + AssertExtensions.Throws("destination", () => Abs(array.AsSpan(1, 5), array.AsSpan(2, 5))); } #endregion #region Add - [Theory] - [MemberData(nameof(TensorLengthsIncluding0))] - public static void Add_TwoTensors(int tensorLength) + [Fact] + public void Add_TwoTensors() { - using BoundedMemory x = 
CreateAndFillTensor(tensorLength); - using BoundedMemory y = CreateAndFillTensor(tensorLength); - using BoundedMemory destination = CreateTensor(tensorLength); - - TensorPrimitives.Add(x, y, destination); - for (int i = 0; i < tensorLength; i++) + Assert.All(Helpers.TensorLengthsIncluding0, tensorLength => { - AssertEqualTolerance(x[i] + y[i], destination[i]); - } + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory y = CreateAndFillTensor(tensorLength); + using BoundedMemory destination = CreateTensor(tensorLength); - float[] xOrig = x.Span.ToArray(); + Add(x, y, destination); + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(Add(x[i], y[i]), destination[i]); + } - // Validate that the destination can be the same as an input. - TensorPrimitives.Add(x, x, x); - for (int i = 0; i < tensorLength; i++) - { - AssertEqualTolerance(xOrig[i] + xOrig[i], x[i]); - } + T[] xOrig = x.Span.ToArray(); + + // Validate that the destination can be the same as an input. + Add(x, x, x); + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(Add(xOrig[i], xOrig[i]), x[i]); + } + }); } - [Theory] - [MemberData(nameof(TensorLengthsIncluding0))] - public static void Add_TwoTensors_InPlace(int tensorLength) + [Fact] + public void Add_TwoTensors_InPlace() { - using BoundedMemory x = CreateAndFillTensor(tensorLength); - float[] xOrig = x.Span.ToArray(); + Assert.All(Helpers.TensorLengthsIncluding0, tensorLength => + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + T[] xOrig = x.Span.ToArray(); - TensorPrimitives.Add(x, x, x); + Add(x, x, x); - for (int i = 0; i < tensorLength; i++) - { - AssertEqualTolerance(xOrig[i] + xOrig[i], x[i]); - } + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(Add(xOrig[i], xOrig[i]), x[i]); + } + }); } - [Theory] - [MemberData(nameof(TensorLengths))] - public static void Add_TwoTensors_ThrowsForMismatchedLengths(int tensorLength) + [Fact] + public void Add_TwoTensors_ThrowsForMismatchedLengths() { - using BoundedMemory x = CreateAndFillTensor(tensorLength); - using BoundedMemory y = CreateAndFillTensor(tensorLength - 1); - using BoundedMemory destination = CreateTensor(tensorLength); + Assert.All(Helpers.TensorLengths, tensorLength => + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory y = CreateAndFillTensor(tensorLength - 1); + using BoundedMemory destination = CreateTensor(tensorLength); - Assert.Throws(() => TensorPrimitives.Add(x, y, destination)); - Assert.Throws(() => TensorPrimitives.Add(y, x, destination)); + Assert.Throws(() => Add(x, y, destination)); + Assert.Throws(() => Add(y, x, destination)); + }); } - [Theory] - [MemberData(nameof(TensorLengths))] - public static void Add_TwoTensors_ThrowsForTooShortDestination(int tensorLength) + [Fact] + public void Add_TwoTensors_ThrowsForTooShortDestination() { - using BoundedMemory x = CreateAndFillTensor(tensorLength); - using BoundedMemory y = CreateAndFillTensor(tensorLength); - using BoundedMemory destination = CreateTensor(tensorLength - 1); + Assert.All(Helpers.TensorLengths, tensorLength => + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory y = CreateAndFillTensor(tensorLength); + using BoundedMemory destination = CreateTensor(tensorLength - 1); - AssertExtensions.Throws("destination", () => TensorPrimitives.Add(x, y, destination)); + AssertExtensions.Throws("destination", () => Add(x, y, destination)); + }); } [Fact] - public static void 
Add_TwoTensors_ThrowsForOverlapppingInputsWithOutputs() + public void Add_TwoTensors_ThrowsForOverlapppingInputsWithOutputs() { - float[] array = new float[10]; - AssertExtensions.Throws("destination", () => TensorPrimitives.Add(array.AsSpan(1, 2), array.AsSpan(5, 2), array.AsSpan(0, 2))); - AssertExtensions.Throws("destination", () => TensorPrimitives.Add(array.AsSpan(1, 2), array.AsSpan(5, 2), array.AsSpan(2, 2))); - AssertExtensions.Throws("destination", () => TensorPrimitives.Add(array.AsSpan(1, 2), array.AsSpan(5, 2), array.AsSpan(4, 2))); - AssertExtensions.Throws("destination", () => TensorPrimitives.Add(array.AsSpan(1, 2), array.AsSpan(5, 2), array.AsSpan(6, 2))); + T[] array = new T[10]; + AssertExtensions.Throws("destination", () => Add(array.AsSpan(1, 2), array.AsSpan(5, 2), array.AsSpan(0, 2))); + AssertExtensions.Throws("destination", () => Add(array.AsSpan(1, 2), array.AsSpan(5, 2), array.AsSpan(2, 2))); + AssertExtensions.Throws("destination", () => Add(array.AsSpan(1, 2), array.AsSpan(5, 2), array.AsSpan(4, 2))); + AssertExtensions.Throws("destination", () => Add(array.AsSpan(1, 2), array.AsSpan(5, 2), array.AsSpan(6, 2))); } - [Theory] - [MemberData(nameof(TensorLengthsIncluding0))] - public static void Add_TensorScalar(int tensorLength) + [Fact] + public void Add_TensorScalar() { - using BoundedMemory x = CreateAndFillTensor(tensorLength); - float y = NextSingle(); - using BoundedMemory destination = CreateTensor(tensorLength); + Assert.All(Helpers.TensorLengthsIncluding0, tensorLength => + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + T y = NextRandom(); + using BoundedMemory destination = CreateTensor(tensorLength); - TensorPrimitives.Add(x, y, destination); + Add(x, y, destination); - for (int i = 0; i < tensorLength; i++) - { - AssertEqualTolerance(x[i] + y, destination[i]); - } + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(Add(x[i], y), destination[i]); + } + }); } - [Theory] - [MemberData(nameof(TensorLengthsIncluding0))] - public static void Add_TensorScalar_InPlace(int tensorLength) + [Fact] + public void Add_TensorScalar_InPlace() { - using BoundedMemory x = CreateAndFillTensor(tensorLength); - float[] xOrig = x.Span.ToArray(); - float y = NextSingle(); + Assert.All(Helpers.TensorLengthsIncluding0, tensorLength => + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + T[] xOrig = x.Span.ToArray(); + T y = NextRandom(); - TensorPrimitives.Add(x, y, x); + Add(x, y, x); - for (int i = 0; i < tensorLength; i++) - { - AssertEqualTolerance(xOrig[i] + y, x[i]); - } + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(Add(xOrig[i], y), x[i]); + } + }); } - [Theory] - [MemberData(nameof(TensorLengths))] - public static void Add_TensorScalar_ThrowsForTooShortDestination(int tensorLength) + [Fact] + public void Add_TensorScalar_ThrowsForTooShortDestination() { - using BoundedMemory x = CreateAndFillTensor(tensorLength); - float y = NextSingle(); - using BoundedMemory destination = CreateTensor(tensorLength - 1); + Assert.All(Helpers.TensorLengths, tensorLength => + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + T y = NextRandom(); + using BoundedMemory destination = CreateTensor(tensorLength - 1); - AssertExtensions.Throws("destination", () => TensorPrimitives.Add(x, y, destination)); + AssertExtensions.Throws("destination", () => Add(x, y, destination)); + }); } [Fact] - public static void Add_TensorScalar_ThrowsForOverlapppingInputsWithOutputs() + public void 
Add_TensorScalar_ThrowsForOverlapppingInputsWithOutputs() { - float[] array = new float[10]; - AssertExtensions.Throws("destination", () => TensorPrimitives.Add(array.AsSpan(1, 2), 42, array.AsSpan(0, 2))); - AssertExtensions.Throws("destination", () => TensorPrimitives.Add(array.AsSpan(1, 2), 42, array.AsSpan(2, 2))); + T[] array = new T[10]; + AssertExtensions.Throws("destination", () => Add(array.AsSpan(1, 2), default(T), array.AsSpan(0, 2))); + AssertExtensions.Throws("destination", () => Add(array.AsSpan(1, 2), default(T), array.AsSpan(2, 2))); } #endregion #region AddMultiply - [Theory] - [MemberData(nameof(TensorLengthsIncluding0))] - public static void AddMultiply_ThreeTensors(int tensorLength) + [Fact] + public void AddMultiply_ThreeTensors() { - using BoundedMemory x = CreateAndFillTensor(tensorLength); - using BoundedMemory y = CreateAndFillTensor(tensorLength); - using BoundedMemory multiplier = CreateAndFillTensor(tensorLength); - using BoundedMemory destination = CreateTensor(tensorLength); + Assert.All(Helpers.TensorLengthsIncluding0, tensorLength => + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory y = CreateAndFillTensor(tensorLength); + using BoundedMemory multiplier = CreateAndFillTensor(tensorLength); + using BoundedMemory destination = CreateTensor(tensorLength); - TensorPrimitives.AddMultiply(x, y, multiplier, destination); + AddMultiply(x, y, multiplier, destination); - for (int i = 0; i < tensorLength; i++) - { - AssertEqualTolerance((x[i] + y[i]) * multiplier[i], destination[i]); - } + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(AddMultiply(x[i], y[i], multiplier[i]), destination[i]); + } + }); } - [Theory] - [MemberData(nameof(TensorLengthsIncluding0))] - public static void AddMultiply_ThreeTensors_InPlace(int tensorLength) + [Fact] + public void AddMultiply_ThreeTensors_InPlace() { - using BoundedMemory x = CreateAndFillTensor(tensorLength); - float[] xOrig = x.Span.ToArray(); + Assert.All(Helpers.TensorLengthsIncluding0, tensorLength => + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + T[] xOrig = x.Span.ToArray(); - TensorPrimitives.AddMultiply(x, x, x, x); + AddMultiply(x, x, x, x); - for (int i = 0; i < tensorLength; i++) - { - AssertEqualTolerance((xOrig[i] + xOrig[i]) * xOrig[i], x[i]); - } + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(AddMultiply(xOrig[i], xOrig[i], xOrig[i]), x[i]); + } + }); } - [Theory] - [MemberData(nameof(TensorLengths))] - public static void AddMultiply_ThreeTensors_ThrowsForMismatchedLengths(int tensorLength) + [Fact] + public void AddMultiply_ThreeTensors_ThrowsForMismatchedLengths() { - using BoundedMemory x = CreateAndFillTensor(tensorLength); - using BoundedMemory y = CreateAndFillTensor(tensorLength); - using BoundedMemory z = CreateAndFillTensor(tensorLength - 1); - using BoundedMemory destination = CreateTensor(tensorLength); + Assert.All(Helpers.TensorLengths, tensorLength => + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory y = CreateAndFillTensor(tensorLength); + using BoundedMemory z = CreateAndFillTensor(tensorLength - 1); + using BoundedMemory destination = CreateTensor(tensorLength); - Assert.Throws(() => TensorPrimitives.AddMultiply(x, y, z, destination)); - Assert.Throws(() => TensorPrimitives.AddMultiply(x, z, y, destination)); - Assert.Throws(() => TensorPrimitives.AddMultiply(z, x, y, destination)); + Assert.Throws(() => AddMultiply(x, y, z, destination)); + Assert.Throws(() => 
AddMultiply(x, z, y, destination)); + Assert.Throws(() => AddMultiply(z, x, y, destination)); + }); } - [Theory] - [MemberData(nameof(TensorLengths))] - public static void AddMultiply_ThreeTensors_ThrowsForTooShortDestination(int tensorLength) + [Fact] + public void AddMultiply_ThreeTensors_ThrowsForTooShortDestination() { - using BoundedMemory x = CreateAndFillTensor(tensorLength); - using BoundedMemory y = CreateAndFillTensor(tensorLength); - using BoundedMemory multiplier = CreateAndFillTensor(tensorLength); - using BoundedMemory destination = CreateTensor(tensorLength - 1); + Assert.All(Helpers.TensorLengths, tensorLength => + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory y = CreateAndFillTensor(tensorLength); + using BoundedMemory multiplier = CreateAndFillTensor(tensorLength); + using BoundedMemory destination = CreateTensor(tensorLength - 1); - AssertExtensions.Throws("destination", () => TensorPrimitives.AddMultiply(x, y, multiplier, destination)); + AssertExtensions.Throws("destination", () => AddMultiply(x, y, multiplier, destination)); + }); } [Fact] - public static void AddMultiply_ThreeTensors_ThrowsForOverlapppingInputsWithOutputs() + public void AddMultiply_ThreeTensors_ThrowsForOverlapppingInputsWithOutputs() { - float[] array = new float[10]; - AssertExtensions.Throws("destination", () => TensorPrimitives.AddMultiply(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(7, 2), array.AsSpan(0, 2))); - AssertExtensions.Throws("destination", () => TensorPrimitives.AddMultiply(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(7, 2), array.AsSpan(2, 2))); - AssertExtensions.Throws("destination", () => TensorPrimitives.AddMultiply(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(7, 2), array.AsSpan(3, 2))); - AssertExtensions.Throws("destination", () => TensorPrimitives.AddMultiply(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(7, 2), array.AsSpan(5, 2))); - AssertExtensions.Throws("destination", () => TensorPrimitives.AddMultiply(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(7, 2), array.AsSpan(6, 2))); - AssertExtensions.Throws("destination", () => TensorPrimitives.AddMultiply(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(7, 2), array.AsSpan(8, 2))); + T[] array = new T[10]; + AssertExtensions.Throws("destination", () => AddMultiply(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(7, 2), array.AsSpan(0, 2))); + AssertExtensions.Throws("destination", () => AddMultiply(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(7, 2), array.AsSpan(2, 2))); + AssertExtensions.Throws("destination", () => AddMultiply(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(7, 2), array.AsSpan(3, 2))); + AssertExtensions.Throws("destination", () => AddMultiply(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(7, 2), array.AsSpan(5, 2))); + AssertExtensions.Throws("destination", () => AddMultiply(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(7, 2), array.AsSpan(6, 2))); + AssertExtensions.Throws("destination", () => AddMultiply(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(7, 2), array.AsSpan(8, 2))); } - [Theory] - [MemberData(nameof(TensorLengthsIncluding0))] - public static void AddMultiply_TensorTensorScalar(int tensorLength) + [Fact] + public void AddMultiply_TensorTensorScalar() { - using BoundedMemory x = CreateAndFillTensor(tensorLength); - using BoundedMemory y = CreateAndFillTensor(tensorLength); - float multiplier = NextSingle(); - using BoundedMemory destination = CreateTensor(tensorLength); + 
Assert.All(Helpers.TensorLengthsIncluding0, tensorLength => + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory y = CreateAndFillTensor(tensorLength); + T multiplier = NextRandom(); + using BoundedMemory destination = CreateTensor(tensorLength); - TensorPrimitives.AddMultiply(x, y, multiplier, destination); + AddMultiply(x, y, multiplier, destination); - for (int i = 0; i < tensorLength; i++) - { - AssertEqualTolerance((x[i] + y[i]) * multiplier, destination[i]); - } + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(AddMultiply(x[i], y[i], multiplier), destination[i]); + } + }); } - [Theory] - [MemberData(nameof(TensorLengthsIncluding0))] - public static void AddMultiply_TensorTensorScalar_InPlace(int tensorLength) + [Fact] + public void AddMultiply_TensorTensorScalar_InPlace() { - using BoundedMemory x = CreateAndFillTensor(tensorLength); - float[] xOrig = x.Span.ToArray(); - float multiplier = NextSingle(); + Assert.All(Helpers.TensorLengthsIncluding0, tensorLength => + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + T[] xOrig = x.Span.ToArray(); + T multiplier = NextRandom(); - TensorPrimitives.AddMultiply(x, x, multiplier, x); + AddMultiply(x, x, multiplier, x); - for (int i = 0; i < tensorLength; i++) - { - AssertEqualTolerance((xOrig[i] + xOrig[i]) * multiplier, x[i]); - } + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(AddMultiply(xOrig[i], xOrig[i], multiplier), x[i]); + } + }); } - [Theory] - [MemberData(nameof(TensorLengths))] - public static void AddMultiply_TensorTensorScalar_ThrowsForMismatchedLengths_x_y(int tensorLength) + [Fact] + public void AddMultiply_TensorTensorScalar_ThrowsForMismatchedLengths_x_y() { - using BoundedMemory x = CreateAndFillTensor(tensorLength); - using BoundedMemory y = CreateAndFillTensor(tensorLength - 1); - float multiplier = NextSingle(); - using BoundedMemory destination = CreateTensor(tensorLength); + Assert.All(Helpers.TensorLengths, tensorLength => + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory y = CreateAndFillTensor(tensorLength - 1); + T multiplier = NextRandom(); + using BoundedMemory destination = CreateTensor(tensorLength); - Assert.Throws(() => TensorPrimitives.AddMultiply(x, y, multiplier, destination)); - Assert.Throws(() => TensorPrimitives.AddMultiply(y, x, multiplier, destination)); + Assert.Throws(() => AddMultiply(x, y, multiplier, destination)); + Assert.Throws(() => AddMultiply(y, x, multiplier, destination)); + }); } - [Theory] - [MemberData(nameof(TensorLengths))] - public static void AddMultiply_TensorTensorScalar_ThrowsForTooShortDestination(int tensorLength) + [Fact] + public void AddMultiply_TensorTensorScalar_ThrowsForTooShortDestination() { - using BoundedMemory x = CreateAndFillTensor(tensorLength); - using BoundedMemory y = CreateAndFillTensor(tensorLength); - float multiplier = NextSingle(); - using BoundedMemory destination = CreateTensor(tensorLength - 1); + Assert.All(Helpers.TensorLengths, tensorLength => + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory y = CreateAndFillTensor(tensorLength); + T multiplier = NextRandom(); + using BoundedMemory destination = CreateTensor(tensorLength - 1); - AssertExtensions.Throws("destination", () => TensorPrimitives.AddMultiply(x, y, multiplier, destination)); + AssertExtensions.Throws("destination", () => AddMultiply(x, y, multiplier, destination)); + }); } [Fact] - public static void 
AddMultiply_TensorTensorScalar_ThrowsForOverlapppingInputsWithOutputs() + public void AddMultiply_TensorTensorScalar_ThrowsForOverlapppingInputsWithOutputs() { - float[] array = new float[10]; - AssertExtensions.Throws("destination", () => TensorPrimitives.AddMultiply(array.AsSpan(1, 2), array.AsSpan(4, 2), 42, array.AsSpan(0, 2))); - AssertExtensions.Throws("destination", () => TensorPrimitives.AddMultiply(array.AsSpan(1, 2), array.AsSpan(4, 2), 42, array.AsSpan(2, 2))); - AssertExtensions.Throws("destination", () => TensorPrimitives.AddMultiply(array.AsSpan(1, 2), array.AsSpan(4, 2), 42, array.AsSpan(3, 2))); - AssertExtensions.Throws("destination", () => TensorPrimitives.AddMultiply(array.AsSpan(1, 2), array.AsSpan(4, 2), 42, array.AsSpan(5, 2))); + T[] array = new T[10]; + AssertExtensions.Throws("destination", () => AddMultiply(array.AsSpan(1, 2), array.AsSpan(4, 2), default(T), array.AsSpan(0, 2))); + AssertExtensions.Throws("destination", () => AddMultiply(array.AsSpan(1, 2), array.AsSpan(4, 2), default(T), array.AsSpan(2, 2))); + AssertExtensions.Throws("destination", () => AddMultiply(array.AsSpan(1, 2), array.AsSpan(4, 2), default(T), array.AsSpan(3, 2))); + AssertExtensions.Throws("destination", () => AddMultiply(array.AsSpan(1, 2), array.AsSpan(4, 2), default(T), array.AsSpan(5, 2))); } - [Theory] - [MemberData(nameof(TensorLengthsIncluding0))] - public static void AddMultiply_TensorScalarTensor(int tensorLength) + [Fact] + public void AddMultiply_TensorScalarTensor() { - using BoundedMemory x = CreateAndFillTensor(tensorLength); - float y = NextSingle(); - using BoundedMemory multiplier = CreateAndFillTensor(tensorLength); - using BoundedMemory destination = CreateTensor(tensorLength); + Assert.All(Helpers.TensorLengthsIncluding0, tensorLength => + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + T y = NextRandom(); + using BoundedMemory multiplier = CreateAndFillTensor(tensorLength); + using BoundedMemory destination = CreateTensor(tensorLength); - TensorPrimitives.AddMultiply(x, y, multiplier, destination); + AddMultiply(x, y, multiplier, destination); - for (int i = 0; i < tensorLength; i++) - { - AssertEqualTolerance((x[i] + y) * multiplier[i], destination[i]); - } + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(AddMultiply(x[i], y, multiplier[i]), destination[i]); + } + }); } - [Theory] - [MemberData(nameof(TensorLengthsIncluding0))] - public static void AddMultiply_TensorScalarTensor_InPlace(int tensorLength) + [Fact] + public void AddMultiply_TensorScalarTensor_InPlace() { - using BoundedMemory x = CreateAndFillTensor(tensorLength); - float[] xOrig = x.Span.ToArray(); - float y = NextSingle(); + Assert.All(Helpers.TensorLengthsIncluding0, tensorLength => + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + T[] xOrig = x.Span.ToArray(); + T y = NextRandom(); - TensorPrimitives.AddMultiply(x, y, x, x); + AddMultiply(x, y, x, x); - for (int i = 0; i < tensorLength; i++) - { - AssertEqualTolerance((xOrig[i] + y) * xOrig[i], x[i]); - } + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(AddMultiply(xOrig[i], y, xOrig[i]), x[i]); + } + }); } - [Theory] - [MemberData(nameof(TensorLengths))] - public static void AddMultiply_TensorScalarTensor_ThrowsForMismatchedLengths_x_z(int tensorLength) + [Fact] + public void AddMultiply_TensorScalarTensor_ThrowsForMismatchedLengths_x_z() { - using BoundedMemory x = CreateAndFillTensor(tensorLength); - float y = NextSingle(); - using BoundedMemory z = 
CreateAndFillTensor(tensorLength - 1); - using BoundedMemory destination = CreateTensor(tensorLength); + Assert.All(Helpers.TensorLengths, tensorLength => + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + T y = NextRandom(); + using BoundedMemory z = CreateAndFillTensor(tensorLength - 1); + using BoundedMemory destination = CreateTensor(tensorLength); - Assert.Throws(() => TensorPrimitives.AddMultiply(x, y, z, destination)); - Assert.Throws(() => TensorPrimitives.AddMultiply(z, y, x, destination)); + Assert.Throws(() => AddMultiply(x, y, z, destination)); + Assert.Throws(() => AddMultiply(z, y, x, destination)); + }); } - [Theory] - [MemberData(nameof(TensorLengths))] - public static void AddMultiply_TensorScalarTensor_ThrowsForTooShortDestination(int tensorLength) + [Fact] + public void AddMultiply_TensorScalarTensor_ThrowsForTooShortDestination() { - using BoundedMemory x = CreateAndFillTensor(tensorLength); - float y = NextSingle(); - using BoundedMemory multiplier = CreateAndFillTensor(tensorLength); - using BoundedMemory destination = CreateTensor(tensorLength - 1); + Assert.All(Helpers.TensorLengths, tensorLength => + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + T y = NextRandom(); + using BoundedMemory multiplier = CreateAndFillTensor(tensorLength); + using BoundedMemory destination = CreateTensor(tensorLength - 1); - AssertExtensions.Throws("destination", () => TensorPrimitives.AddMultiply(x, y, multiplier, destination)); + AssertExtensions.Throws("destination", () => AddMultiply(x, y, multiplier, destination)); + }); } [Fact] - public static void AddMultiply_TensorScalarTensor_ThrowsForOverlapppingInputsWithOutputs() + public void AddMultiply_TensorScalarTensor_ThrowsForOverlapppingInputsWithOutputs() { - float[] array = new float[10]; - AssertExtensions.Throws("destination", () => TensorPrimitives.AddMultiply(array.AsSpan(1, 2), 42, array.AsSpan(4, 2), array.AsSpan(0, 2))); - AssertExtensions.Throws("destination", () => TensorPrimitives.AddMultiply(array.AsSpan(1, 2), 42, array.AsSpan(4, 2), array.AsSpan(2, 2))); - AssertExtensions.Throws("destination", () => TensorPrimitives.AddMultiply(array.AsSpan(1, 2), 42, array.AsSpan(4, 2), array.AsSpan(3, 2))); - AssertExtensions.Throws("destination", () => TensorPrimitives.AddMultiply(array.AsSpan(1, 2), 42, array.AsSpan(4, 2), array.AsSpan(5, 2))); + T[] array = new T[10]; + AssertExtensions.Throws("destination", () => AddMultiply(array.AsSpan(1, 2), default(T), array.AsSpan(4, 2), array.AsSpan(0, 2))); + AssertExtensions.Throws("destination", () => AddMultiply(array.AsSpan(1, 2), default(T), array.AsSpan(4, 2), array.AsSpan(2, 2))); + AssertExtensions.Throws("destination", () => AddMultiply(array.AsSpan(1, 2), default(T), array.AsSpan(4, 2), array.AsSpan(3, 2))); + AssertExtensions.Throws("destination", () => AddMultiply(array.AsSpan(1, 2), default(T), array.AsSpan(4, 2), array.AsSpan(5, 2))); } #endregion #region Cosh - [Theory] - [MemberData(nameof(TensorLengthsIncluding0))] - public static void Cosh(int tensorLength) + [Fact] + public void Cosh_AllLengths() { - using BoundedMemory x = CreateAndFillTensor(tensorLength); - using BoundedMemory destination = CreateTensor(tensorLength); - - TensorPrimitives.Cosh(x, destination); + if (!IsFloatingPoint) return; - for (int i = 0; i < tensorLength; i++) + Assert.All(Helpers.TensorLengthsIncluding0, tensorLength => { - AssertEqualTolerance(MathF.Cosh(x[i]), destination[i]); - } - } - - [Theory] - [MemberData(nameof(TensorLengthsIncluding0))] - public 
static void Cosh_InPlace(int tensorLength) - { - using BoundedMemory x = CreateAndFillTensor(tensorLength); - float[] xOrig = x.Span.ToArray(); + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory destination = CreateTensor(tensorLength); - TensorPrimitives.Cosh(x, x); + Cosh(x, destination); - for (int i = 0; i < tensorLength; i++) - { - AssertEqualTolerance(MathF.Cosh(xOrig[i]), x[i]); - } + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(Cosh(x[i]), destination[i]); + } + }); } - [Theory] - [MemberData(nameof(TensorLengths))] - public static void Cosh_SpecialValues(int tensorLength) + [Fact] + public void Cosh_InPlace() { - using BoundedMemory x = CreateAndFillTensor(tensorLength); - using BoundedMemory destination = CreateTensor(tensorLength); + if (!IsFloatingPoint) return; - RunForEachSpecialValue(() => + Assert.All(Helpers.TensorLengthsIncluding0, tensorLength => { - TensorPrimitives.Cosh(x, destination); + using BoundedMemory x = CreateAndFillTensor(tensorLength); + T[] xOrig = x.Span.ToArray(); + + Cosh(x, x); + for (int i = 0; i < tensorLength; i++) { - AssertEqualTolerance(MathF.Cosh(x[i]), destination[i]); + AssertEqualTolerance(Cosh(xOrig[i]), x[i]); } - }, x); + }); } - [Theory] - [MemberData(nameof(VectorLengthAndIteratedRange), new object[] { -100f, 100f, 3f })] - public static void Cosh_ValueRange(int vectorLength, float element) + [Fact] + public void Cosh_SpecialValues() { - float[] x = new float[vectorLength]; - float[] dest = new float[vectorLength]; + if (!IsFloatingPoint) return; + + Assert.All(Helpers.TensorLengths, tensorLength => + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory destination = CreateTensor(tensorLength); + + RunForEachSpecialValue(() => + { + Cosh(x, destination); + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(Cosh(x[i]), destination[i]); + } + }, x); + }); + } - x.AsSpan().Fill(element); - TensorPrimitives.Cosh(x, dest); + [Fact] + public void Cosh_ValueRange() + { + if (!IsFloatingPoint) return; - float expected = MathF.Cosh(element); - foreach (float actual in dest) + Assert.All(VectorLengthAndIteratedRange(ConvertFromSingle(-100f), ConvertFromSingle(100f), ConvertFromSingle(3f)), arg => { - AssertEqualTolerance(expected, actual); - } + T[] x = new T[arg.Length]; + T[] dest = new T[arg.Length]; + + x.AsSpan().Fill(arg.Element); + Cosh(x, dest); + + T expected = Cosh(arg.Element); + foreach (T actual in dest) + { + AssertEqualTolerance(expected, actual); + } + }); } - [Theory] - [MemberData(nameof(TensorLengths))] - public static void Cosh_ThrowsForTooShortDestination(int tensorLength) + [Fact] + public void Cosh_ThrowsForTooShortDestination() { - using BoundedMemory x = CreateAndFillTensor(tensorLength); - using BoundedMemory destination = CreateTensor(tensorLength - 1); + if (!IsFloatingPoint) return; + + Assert.All(Helpers.TensorLengths, tensorLength => + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory destination = CreateTensor(tensorLength - 1); - AssertExtensions.Throws("destination", () => TensorPrimitives.Cosh(x, destination)); + AssertExtensions.Throws("destination", () => Cosh(x, destination)); + }); } [Fact] - public static void Cosh_ThrowsForOverlapppingInputsWithOutputs() + public void Cosh_ThrowsForOverlapppingInputsWithOutputs() { - float[] array = new float[10]; - AssertExtensions.Throws("destination", () => TensorPrimitives.Cosh(array.AsSpan(1, 2), array.AsSpan(0, 2))); - 
AssertExtensions.Throws("destination", () => TensorPrimitives.Cosh(array.AsSpan(1, 2), array.AsSpan(2, 2))); + if (!IsFloatingPoint) return; + + T[] array = new T[10]; + AssertExtensions.Throws("destination", () => Cosh(array.AsSpan(1, 2), array.AsSpan(0, 2))); + AssertExtensions.Throws("destination", () => Cosh(array.AsSpan(1, 2), array.AsSpan(2, 2))); } #endregion #region CosineSimilarity - [Theory] - [MemberData(nameof(TensorLengths))] - public static void CosineSimilarity_ThrowsForMismatchedLengths(int tensorLength) + [Fact] + public void CosineSimilarity_ThrowsForMismatchedLengths() { - using BoundedMemory x = CreateAndFillTensor(tensorLength); - using BoundedMemory y = CreateAndFillTensor(tensorLength - 1); + if (!IsFloatingPoint) return; + + Assert.All(Helpers.TensorLengths, tensorLength => + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory y = CreateAndFillTensor(tensorLength - 1); - Assert.Throws(() => TensorPrimitives.CosineSimilarity(x, y)); - Assert.Throws(() => TensorPrimitives.CosineSimilarity(y, x)); + Assert.Throws(() => CosineSimilarity(x, y)); + Assert.Throws(() => CosineSimilarity(y, x)); + }); } [Fact] - public static void CosineSimilarity_ThrowsForEmpty() + public void CosineSimilarity_ThrowsForEmpty() { - Assert.Throws(() => TensorPrimitives.CosineSimilarity(ReadOnlySpan.Empty, ReadOnlySpan.Empty)); - Assert.Throws(() => TensorPrimitives.CosineSimilarity(ReadOnlySpan.Empty, CreateTensor(1))); - Assert.Throws(() => TensorPrimitives.CosineSimilarity(CreateTensor(1), ReadOnlySpan.Empty)); - } + if (!IsFloatingPoint) return; - [Theory] - [InlineData(new float[] { 3, 2, 0, 5 }, new float[] { 1, 0, 0, 0 }, 0.48666f)] - [InlineData(new float[] { 1, 1, 1, 1, 1, 0 }, new float[] { 1, 1, 1, 1, 0, 1 }, 0.80f)] - public static void CosineSimilarity_KnownValues(float[] x, float[] y, float expectedResult) - { - AssertEqualTolerance(expectedResult, TensorPrimitives.CosineSimilarity(x, y)); + Assert.Throws(() => CosineSimilarity(ReadOnlySpan.Empty, ReadOnlySpan.Empty)); + Assert.Throws(() => CosineSimilarity(ReadOnlySpan.Empty, CreateTensor(1))); + Assert.Throws(() => CosineSimilarity(CreateTensor(1), ReadOnlySpan.Empty)); } - [Theory] - [MemberData(nameof(TensorLengths))] - public static void CosineSimilarity(int tensorLength) + [Fact] + public void CosineSimilarity_AllLengths() { - using BoundedMemory x = CreateAndFillTensor(tensorLength); - using BoundedMemory y = CreateAndFillTensor(tensorLength); + if (!IsFloatingPoint) return; - float dot = 0f, squareX = 0f, squareY = 0f; - for (int i = 0; i < x.Length; i++) + Assert.All(Helpers.TensorLengths, tensorLength => { - dot += x[i] * y[i]; - squareX += x[i] * x[i]; - squareY += y[i] * y[i]; - } + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory y = CreateAndFillTensor(tensorLength); + + T dot = default, squareX = default, squareY = default; + for (int i = 0; i < x.Length; i++) + { + dot = Add(dot, Multiply(x[i], y[i])); + squareX = Add(squareX, Multiply(x[i], x[i])); + squareY = Add(squareY, Multiply(y[i], y[i])); + } - AssertEqualTolerance(dot / (MathF.Sqrt(squareX) * MathF.Sqrt(squareY)), TensorPrimitives.CosineSimilarity(x, y)); + AssertEqualTolerance(Divide(dot, Multiply(Sqrt(squareX), Sqrt(squareY))), CosineSimilarity(x, y)); + }); } #endregion #region Distance [Fact] - public static void Distance_ThrowsForEmpty() + public void Distance_ThrowsForEmpty() { - Assert.Throws(() => TensorPrimitives.Distance(ReadOnlySpan.Empty, ReadOnlySpan.Empty)); - Assert.Throws(() 

         #region Distance
         [Fact]
-        public static void Distance_ThrowsForEmpty()
+        public void Distance_ThrowsForEmpty()
         {
-            Assert.Throws<ArgumentException>(() => TensorPrimitives.Distance(ReadOnlySpan<float>.Empty, ReadOnlySpan<float>.Empty));
-            Assert.Throws<ArgumentException>(() => TensorPrimitives.Distance(ReadOnlySpan<float>.Empty, CreateTensor(1)));
-            Assert.Throws<ArgumentException>(() => TensorPrimitives.Distance(CreateTensor(1), ReadOnlySpan<float>.Empty));
+            if (!IsFloatingPoint) return;
+
+            Assert.Throws<ArgumentException>(() => Distance(ReadOnlySpan<T>.Empty, ReadOnlySpan<T>.Empty));
+            Assert.Throws<ArgumentException>(() => Distance(ReadOnlySpan<T>.Empty, CreateTensor(1)));
+            Assert.Throws<ArgumentException>(() => Distance(CreateTensor(1), ReadOnlySpan<T>.Empty));
         }

-        [Theory]
-        [MemberData(nameof(TensorLengths))]
-        public static void Distance_ThrowsForMismatchedLengths(int tensorLength)
+        [Fact]
+        public void Distance_ThrowsForMismatchedLengths()
         {
-            using BoundedMemory<float> x = CreateAndFillTensor(tensorLength);
-            using BoundedMemory<float> y = CreateAndFillTensor(tensorLength - 1);
+            if (!IsFloatingPoint) return;

-            Assert.Throws<ArgumentException>(() => TensorPrimitives.Distance(x, y));
-            Assert.Throws<ArgumentException>(() => TensorPrimitives.Distance(y, x));
-        }
+            Assert.All(Helpers.TensorLengths, tensorLength =>
+            {
+                using BoundedMemory<T> x = CreateAndFillTensor(tensorLength);
+                using BoundedMemory<T> y = CreateAndFillTensor(tensorLength - 1);

-        [Theory]
-        [InlineData(new float[] { 3, 2 }, new float[] { 4, 1 }, 1.4142f)]
-        [InlineData(new float[] { 0, 4 }, new float[] { 6, 2 }, 6.3245f)]
-        [InlineData(new float[] { 1, 2, 3 }, new float[] { 4, 5, 6 }, 5.19615f)]
-        [InlineData(new float[] { 5, 1, 6, 10 }, new float[] { 7, 2, 8, 4 }, 6.7082f)]
-        public static void Distance_KnownValues(float[] x, float[] y, float expectedResult)
-        {
-            AssertEqualTolerance(expectedResult, TensorPrimitives.Distance(x, y));
+                Assert.Throws<ArgumentException>(() => Distance(x, y));
+                Assert.Throws<ArgumentException>(() => Distance(y, x));
+            });
         }

-        [Theory]
-        [MemberData(nameof(TensorLengths))]
-        public static void Distance(int tensorLength)
+        [Fact]
+        public void Distance_AllLengths()
         {
-            using BoundedMemory<float> x = CreateAndFillTensor(tensorLength);
-            using BoundedMemory<float> y = CreateAndFillTensor(tensorLength);
+            if (!IsFloatingPoint) return;

-            float distance = 0f;
-            for (int i = 0; i < x.Length; i++)
+            Assert.All(Helpers.TensorLengths, tensorLength =>
             {
-                distance += (x[i] - y[i]) * (x[i] - y[i]);
-            }
+                using BoundedMemory<T> x = CreateAndFillTensor(tensorLength);
+                using BoundedMemory<T> y = CreateAndFillTensor(tensorLength);
+
+                T distance = default;
+                for (int i = 0; i < x.Length; i++)
+                {
+                    distance = Add(distance, Multiply(Subtract(x[i], y[i]), Subtract(x[i], y[i])));
+                }

-            AssertEqualTolerance(MathF.Sqrt(distance), TensorPrimitives.Distance(x, y));
+                AssertEqualTolerance(Sqrt(distance), Distance(x, y));
+            });
         }
         #endregion
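The Distance reference is the usual Euclidean form, sqrt of the sum of squared differences. A generic-math sketch of the same computation, shown here only to make the formula explicit outside the suite's wrappers:

```csharp
using System;
using System.Numerics;

static class DistanceReference
{
    // Scalar reference: sqrt(sum (x[i] - y[i])^2). The subtraction, addition,
    // and multiplication operators all come with IRootFunctions<T>'s base interfaces.
    public static T Compute<T>(ReadOnlySpan<T> x, ReadOnlySpan<T> y)
        where T : IRootFunctions<T>
    {
        T sum = T.AdditiveIdentity;
        for (int i = 0; i < x.Length; i++)
        {
            T d = x[i] - y[i];
            sum += d * d;
        }
        return T.Sqrt(sum);
    }
}
```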

         #region Divide
-        [Theory]
-        [MemberData(nameof(TensorLengthsIncluding0))]
-        public static void Divide_TwoTensors(int tensorLength)
+        [Fact]
+        public void Divide_TwoTensors()
         {
-            using BoundedMemory<float> x = CreateAndFillTensor(tensorLength);
-            using BoundedMemory<float> y = CreateAndFillTensor(tensorLength);
-            using BoundedMemory<float> destination = CreateTensor(tensorLength);
+            Assert.All(Helpers.TensorLengthsIncluding0, tensorLength =>
+            {
+                using BoundedMemory<T> x = CreateAndFillTensor(tensorLength);
+                using BoundedMemory<T> y = CreateTensor(tensorLength);
+                FillTensor(y, Zero);
+                using BoundedMemory<T> destination = CreateTensor(tensorLength);

-            TensorPrimitives.Divide(x, y, destination);
+                Divide(x, y, destination);

-            for (int i = 0; i < tensorLength; i++)
-            {
-                AssertEqualTolerance(x[i] / y[i], destination[i]);
-            }
+                for (int i = 0; i < tensorLength; i++)
+                {
+                    AssertEqualTolerance(Divide(x[i], y[i]), destination[i]);
+                }
+            });
         }

-        [Theory]
-        [MemberData(nameof(TensorLengthsIncluding0))]
-        public static void Divide_TwoTensors_InPlace(int tensorLength)
+        [Fact]
+        public void Divide_TwoTensors_InPlace()
         {
-            using BoundedMemory<float> x = CreateAndFillTensor(tensorLength);
-            float[] xOrig = x.Span.ToArray();
+            Assert.All(Helpers.TensorLengthsIncluding0, tensorLength =>
+            {
+                using BoundedMemory<T> x = CreateTensor(tensorLength);
+                FillTensor(x, Zero);
+                T[] xOrig = x.Span.ToArray();

-            TensorPrimitives.Divide(x, x, x);
+                Divide(x, x, x);

-            for (int i = 0; i < tensorLength; i++)
-            {
-                AssertEqualTolerance(xOrig[i] / xOrig[i], x[i]);
-            }
+                for (int i = 0; i < tensorLength; i++)
+                {
+                    AssertEqualTolerance(Divide(xOrig[i], xOrig[i]), x[i]);
+                }
+            });
         }

-        [Theory]
-        [MemberData(nameof(TensorLengths))]
-        public static void Divide_TwoTensors_ThrowsForMismatchedLengths(int tensorLength)
+        [Fact]
+        public void Divide_TwoTensors_ThrowsForMismatchedLengths()
         {
-            using BoundedMemory<float> x = CreateAndFillTensor(tensorLength);
-            using BoundedMemory<float> y = CreateAndFillTensor(tensorLength - 1);
-            using BoundedMemory<float> destination = CreateTensor(tensorLength);
+            Assert.All(Helpers.TensorLengths, tensorLength =>
+            {
+                using BoundedMemory<T> x = CreateAndFillTensor(tensorLength);
+                using BoundedMemory<T> y = CreateAndFillTensor(tensorLength - 1);
+                using BoundedMemory<T> destination = CreateTensor(tensorLength);

-            Assert.Throws<ArgumentException>(() => TensorPrimitives.Divide(x, y, destination));
-            Assert.Throws<ArgumentException>(() => TensorPrimitives.Divide(y, x, destination));
+                Assert.Throws<ArgumentException>(() => Divide(x, y, destination));
+                Assert.Throws<ArgumentException>(() => Divide(y, x, destination));
+            });
         }

-        [Theory]
-        [MemberData(nameof(TensorLengths))]
-        public static void Divide_TwoTensors_ThrowsForTooShortDestination(int tensorLength)
+        [Fact]
+        public void Divide_TwoTensors_ThrowsForTooShortDestination()
         {
-            using BoundedMemory<float> x = CreateAndFillTensor(tensorLength);
-            using BoundedMemory<float> y = CreateAndFillTensor(tensorLength);
-            using BoundedMemory<float> destination = CreateTensor(tensorLength - 1);
+            Assert.All(Helpers.TensorLengths, tensorLength =>
+            {
+                using BoundedMemory<T> x = CreateAndFillTensor(tensorLength);
+                using BoundedMemory<T> y = CreateAndFillTensor(tensorLength);
+                using BoundedMemory<T> destination = CreateTensor(tensorLength - 1);

-            AssertExtensions.Throws<ArgumentException>("destination", () => TensorPrimitives.Divide(x, y, destination));
+                AssertExtensions.Throws<ArgumentException>("destination", () => Divide(x, y, destination));
+            });
         }

         [Fact]
-        public static void Divide_TwoTensors_ThrowsForOverlapppingInputsWithOutputs()
+        public void Divide_TwoTensors_ThrowsForOverlapppingInputsWithOutputs()
         {
-            float[] array = new float[10];
-            AssertExtensions.Throws<ArgumentException>("destination", () => TensorPrimitives.Divide(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(0, 2)));
-            AssertExtensions.Throws<ArgumentException>("destination", () => TensorPrimitives.Divide(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(2, 2)));
-            AssertExtensions.Throws<ArgumentException>("destination", () => TensorPrimitives.Divide(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(3, 2)));
-            AssertExtensions.Throws<ArgumentException>("destination", () => TensorPrimitives.Divide(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(5, 2)));
+            T[] array = new T[10];
+            AssertExtensions.Throws<ArgumentException>("destination", () => Divide(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(0, 2)));
+            AssertExtensions.Throws<ArgumentException>("destination", () => Divide(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(2, 2)));
+            AssertExtensions.Throws<ArgumentException>("destination", () => Divide(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(3, 2)));
+            AssertExtensions.Throws<ArgumentException>("destination", () => Divide(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(5, 2)));
         }

-        [Theory]
-        [MemberData(nameof(TensorLengthsIncluding0))]
-        public static void Divide_TensorScalar(int tensorLength)
+        [Fact]
+        public void Divide_TensorScalar()
         {
-            using BoundedMemory<float> x = CreateAndFillTensor(tensorLength);
-            float y = NextSingle();
-            using BoundedMemory<float> destination = CreateTensor(tensorLength);
+            Assert.All(Helpers.TensorLengthsIncluding0, tensorLength =>
+            {
+                using BoundedMemory<T> x = CreateAndFillTensor(tensorLength);
+                T y = NextRandom(default);
+                using BoundedMemory<T> destination = CreateTensor(tensorLength);

-            TensorPrimitives.Divide(x, y, destination);
+                Divide(x, y, destination);

-            for (int i = 0; i < tensorLength; i++)
-            {
-                AssertEqualTolerance(x[i] / y, destination[i]);
-            }
+                for (int i = 0; i < tensorLength; i++)
+                {
+                    AssertEqualTolerance(Divide(x[i], y), destination[i]);
+                }
+            });
         }

-        [Theory]
-        [MemberData(nameof(TensorLengthsIncluding0))]
-        public static void Divide_TensorScalar_InPlace(int tensorLength)
+        [Fact]
+        public void Divide_TensorScalar_InPlace()
         {
-            using BoundedMemory<float> x = CreateAndFillTensor(tensorLength);
-            float[] xOrig = x.Span.ToArray();
-            float y = NextSingle();
+            Assert.All(Helpers.TensorLengthsIncluding0, tensorLength =>
+            {
+                using BoundedMemory<T> x = CreateAndFillTensor(tensorLength);
+                T[] xOrig = x.Span.ToArray();
+                T y = NextRandom(default);

-            TensorPrimitives.Divide(x, y, x);
+                Divide(x, y, x);

-            for (int i = 0; i < tensorLength; i++)
-            {
-                AssertEqualTolerance(xOrig[i] / y, x[i]);
-            }
+                for (int i = 0; i < tensorLength; i++)
+                {
+                    AssertEqualTolerance(Divide(xOrig[i], y), x[i]);
+                }
+            });
         }

-        [Theory]
-        [MemberData(nameof(TensorLengths))]
-        public static void Divide_TensorScalar_ThrowsForTooShortDestination(int tensorLength)
+        [Fact]
+        public void Divide_TensorScalar_ThrowsForTooShortDestination()
         {
-            using BoundedMemory<float> x = CreateAndFillTensor(tensorLength);
-            float y = NextSingle();
-            using BoundedMemory<float> destination = CreateTensor(tensorLength - 1);
+            Assert.All(Helpers.TensorLengths, tensorLength =>
+            {
+                using BoundedMemory<T> x = CreateAndFillTensor(tensorLength);
+                T y = NextRandom();
+                using BoundedMemory<T> destination = CreateTensor(tensorLength - 1);

-            AssertExtensions.Throws<ArgumentException>("destination", () => TensorPrimitives.Divide(x, y, destination));
+                AssertExtensions.Throws<ArgumentException>("destination", () => Divide(x, y, destination));
+            });
         }

         [Fact]
-        public static void Divide_TensorScalar_ThrowsForOverlapppingInputsWithOutputs()
+        public void Divide_TensorScalar_ThrowsForOverlapppingInputsWithOutputs()
         {
-            float[] array = new float[10];
-            AssertExtensions.Throws<ArgumentException>("destination", () => TensorPrimitives.Divide(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(0, 2)));
-            AssertExtensions.Throws<ArgumentException>("destination", () => TensorPrimitives.Divide(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(2, 2)));
-            AssertExtensions.Throws<ArgumentException>("destination", () => TensorPrimitives.Divide(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(3, 2)));
-            AssertExtensions.Throws<ArgumentException>("destination", () => TensorPrimitives.Divide(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(5, 2)));
+            T[] array = new T[10];
+            AssertExtensions.Throws<ArgumentException>("destination", () => Divide(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(0, 2)));
+            AssertExtensions.Throws<ArgumentException>("destination", () => Divide(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(2, 2)));
+            AssertExtensions.Throws<ArgumentException>("destination", () => Divide(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(3, 2)));
+            AssertExtensions.Throws<ArgumentException>("destination", () => Divide(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(5, 2)));
         }
         #endregion
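The ThrowsForOverlapppingInputsWithOutputs tests pin down that a destination which *partially* aliases an input is rejected, while the in-place tests show exact aliasing is allowed. TensorPrimitives' internal validation is not shown in this diff; the following is just one plausible shape of such a guard, using `MemoryExtensions.Overlaps` (a sketch, not the library's actual code):

```csharp
using System;

static class OverlapGuardSketch
{
    // Exact aliasing (element offset 0) supports in-place operation; any other
    // overlap is rejected because vectorized writes would clobber inputs that
    // have not yet been read.
    public static void ValidateNotPartiallyOverlapping<T>(ReadOnlySpan<T> x, Span<T> destination)
    {
        if (x.Overlaps(destination, out int elementOffset) && elementOffset != 0)
        {
            throw new ArgumentException("Input and destination spans must not partially overlap.", nameof(destination));
        }
    }
}
```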

         #region Dot
-        [Theory]
-        [MemberData(nameof(TensorLengths))]
-        public static void Dot_ThrowsForMismatchedLengths_x_y(int tensorLength)
+        [Fact]
+        public void Dot_ThrowsForMismatchedLengths_x_y()
         {
-            using BoundedMemory<float> x = CreateAndFillTensor(tensorLength);
-            using BoundedMemory<float> y = CreateAndFillTensor(tensorLength - 1);
-
-            Assert.Throws<ArgumentException>(() => TensorPrimitives.Dot(x, y));
-            Assert.Throws<ArgumentException>(() => TensorPrimitives.Dot(y, x));
-        }
+            Assert.All(Helpers.TensorLengths, tensorLength =>
+            {
+                using BoundedMemory<T> x = CreateAndFillTensor(tensorLength);
+                using BoundedMemory<T> y = CreateAndFillTensor(tensorLength - 1);

-        [Theory]
-        [InlineData(new float[] { 1, 3, -5 }, new float[] { 4, -2, -1 }, 3)]
-        [InlineData(new float[] { 1, 2, 3 }, new float[] { 4, 5, 6 }, 32)]
-        [InlineData(new float[] { 1, 2, 3, 10, 8 }, new float[] { 4, 5, 6, -2, 7 }, 68)]
-        [InlineData(new float[] { }, new float[] { }, 0)]
-        public static void Dot_KnownValues(float[] x, float[] y, float expectedResult)
-        {
-            AssertEqualTolerance(expectedResult, TensorPrimitives.Dot(x, y));
+                Assert.Throws<ArgumentException>(() => Dot(x, y));
+                Assert.Throws<ArgumentException>(() => Dot(y, x));
+            });
         }

-        [Theory]
-        [MemberData(nameof(TensorLengthsIncluding0))]
-        public static void Dot(int tensorLength)
+        [Fact]
+        public void Dot_AllLengths()
         {
-            using BoundedMemory<float> x = CreateAndFillTensor(tensorLength);
-            using BoundedMemory<float> y = CreateAndFillTensor(tensorLength);
-
-            float dot = 0f;
-            for (int i = 0; i < x.Length; i++)
+            Assert.All(Helpers.TensorLengthsIncluding0, tensorLength =>
             {
-                dot += x[i] * y[i];
-            }
-
-            AssertEqualTolerance(dot, TensorPrimitives.Dot(x, y));
-        }
-        #endregion
+                using BoundedMemory<T> x = CreateAndFillTensor(tensorLength);
+                using BoundedMemory<T> y = CreateAndFillTensor(tensorLength);

-        #region Exp
-        [Theory]
-        [MemberData(nameof(TensorLengthsIncluding0))]
-        public static void Exp(int tensorLength)
-        {
-            using BoundedMemory<float> x = CreateAndFillTensor(tensorLength);
-            using BoundedMemory<float> destination = CreateTensor(tensorLength);
-
-            TensorPrimitives.Exp(x, destination);
+                T dot = default;
+                for (int i = 0; i < x.Length; i++)
+                {
+                    dot = Add(dot, Multiply(x[i], y[i]));
+                }

-            for (int i = 0; i < tensorLength; i++)
-            {
-                AssertEqualTolerance(MathF.Exp(x[i]), destination[i]);
-            }
+                AssertEqualTolerance(dot, Dot(x, y));
+            });
         }
+        #endregion
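Dot is one of the tests that carries no `IsFloatingPoint` guard, because a dot product only needs addition and multiplication plus their identities. A scalar reference under exactly those minimal constraints (illustrative sketch; the suite routes through its Add/Multiply wrappers instead):

```csharp
using System;
using System.Numerics;

static class DotReference
{
    // Scalar reference: sum of x[i] * y[i], starting from the additive identity.
    public static T Compute<T>(ReadOnlySpan<T> x, ReadOnlySpan<T> y)
        where T : IAdditionOperators<T, T, T>, IAdditiveIdentity<T, T>,
                  IMultiplyOperators<T, T, T>, IMultiplicativeIdentity<T, T>
    {
        T dot = T.AdditiveIdentity;
        for (int i = 0; i < x.Length; i++)
        {
            dot += x[i] * y[i];
        }
        return dot;
    }
}
```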

-        [Theory]
-        [MemberData(nameof(TensorLengthsIncluding0))]
-        public static void Exp_InPlace(int tensorLength)
+        #region Exp
+        [Fact]
+        public void Exp_AllLengths()
         {
-            using BoundedMemory<float> x = CreateAndFillTensor(tensorLength);
-            float[] xOrig = x.Span.ToArray();
-
-            TensorPrimitives.Exp(x, x);
+            if (!IsFloatingPoint) return;

-            for (int i = 0; i < tensorLength; i++)
+            Assert.All(Helpers.TensorLengthsIncluding0, tensorLength =>
             {
-                AssertEqualTolerance(MathF.Exp(xOrig[i]), x[i]);
-            }
-        }
+                using BoundedMemory<T> x = CreateAndFillTensor(tensorLength);
+                using BoundedMemory<T> destination = CreateTensor(tensorLength);

-        [Theory]
-        [MemberData(nameof(TensorLengths))]
-        public static void Exp_SpecialValues(int tensorLength)
-        {
-            using BoundedMemory<float> x = CreateAndFillTensor(tensorLength);
-            using BoundedMemory<float> destination = CreateTensor(tensorLength);
+                Exp(x, destination);

-            RunForEachSpecialValue(() =>
-            {
-                TensorPrimitives.Exp(x, destination);
                 for (int i = 0; i < tensorLength; i++)
                 {
-                    AssertEqualTolerance(MathF.Exp(x[i]), destination[i]);
+                    AssertEqualTolerance(Exp(x[i]), destination[i]);
                 }
-            }, x);
+            });
         }

-        [Theory]
-        [MemberData(nameof(TensorLengths))]
-        public static void Exp_ThrowsForTooShortDestination(int tensorLength)
+        [Fact]
+        public void Exp_InPlace()
         {
-            using BoundedMemory<float> x = CreateAndFillTensor(tensorLength);
-            using BoundedMemory<float> destination = CreateTensor(tensorLength - 1);
+            if (!IsFloatingPoint) return;

-            AssertExtensions.Throws<ArgumentException>("destination", () => TensorPrimitives.Exp(x, destination));
-        }
+            Assert.All(Helpers.TensorLengthsIncluding0, tensorLength =>
+            {
+                using BoundedMemory<T> x = CreateAndFillTensor(tensorLength);
+                T[] xOrig = x.Span.ToArray();

-        [Fact]
-        public static void Exp_ThrowsForOverlapppingInputsWithOutputs()
-        {
-            float[] array = new float[10];
-            AssertExtensions.Throws<ArgumentException>("destination", () => TensorPrimitives.Exp(array.AsSpan(1, 2), array.AsSpan(0, 2)));
-            AssertExtensions.Throws<ArgumentException>("destination", () => TensorPrimitives.Exp(array.AsSpan(1, 2), array.AsSpan(2, 2)));
-        }
-        #endregion
+                Exp(x, x);

-        #region IndexOfMax
-        [Fact]
-        public static void IndexOfMax_ReturnsNegative1OnEmpty()
-        {
-            Assert.Equal(-1, TensorPrimitives.IndexOfMax(ReadOnlySpan<float>.Empty));
+                for (int i = 0; i < tensorLength; i++)
+                {
+                    AssertEqualTolerance(Exp(xOrig[i]), x[i]);
+                }
+            });
         }

-        [Theory]
-        [MemberData(nameof(TensorLengths))]
-        public static void IndexOfMax(int tensorLength)
+        [Fact]
+        public void Exp_SpecialValues()
         {
-            foreach (int expected in new[] { 0, tensorLength / 2, tensorLength - 1 })
-            {
-                using BoundedMemory<float> x = CreateAndFillTensor(tensorLength);
-                x[expected] = Enumerable.Max(MemoryMarshal.ToEnumerable(x.Memory)) + 1;
-                Assert.Equal(expected, TensorPrimitives.IndexOfMax(x));
-            }
-        }
+            if (!IsFloatingPoint) return;

-        [Theory]
-        [MemberData(nameof(TensorLengths))]
-        public static void IndexOfMax_FirstNaNReturned(int tensorLength)
-        {
-            foreach (int expected in new[] { 0, tensorLength / 2, tensorLength - 1 })
+            Assert.All(Helpers.TensorLengths, tensorLength =>
             {
-                using BoundedMemory<float> x = CreateAndFillTensor(tensorLength);
-                x[expected] = float.NaN;
-                x[tensorLength - 1] = float.NaN;
-                Assert.Equal(expected, TensorPrimitives.IndexOfMax(x));
-            }
-        }
+                using BoundedMemory<T> x = CreateAndFillTensor(tensorLength);
+                using BoundedMemory<T> destination = CreateTensor(tensorLength);

-        [Fact]
-        public static void IndexOfMax_Negative0LesserThanPositive0()
-        {
-            Assert.Equal(1, TensorPrimitives.IndexOfMax([-0f, +0f]));
-            Assert.Equal(0, TensorPrimitives.IndexOfMax([-0f, -0f, -0f, -0f]));
-            Assert.Equal(4, TensorPrimitives.IndexOfMax([-0f, -0f, -0f, -0f, +0f, +0f, +0f]));
-            Assert.Equal(0, TensorPrimitives.IndexOfMax([+0f, -0f]));
-            Assert.Equal(1, TensorPrimitives.IndexOfMax([-1, -0f]));
-            Assert.Equal(2, TensorPrimitives.IndexOfMax([-1, -0f, 1]));
+                RunForEachSpecialValue(() =>
+                {
+                    Exp(x, destination);
+                    for (int i = 0; i < tensorLength; i++)
+                    {
+                        AssertEqualTolerance(Exp(x[i]), destination[i]);
+                    }
+                }, x);
+            });
         }
-        #endregion

-        #region IndexOfMaxMagnitude
         [Fact]
-        public static void IndexOfMaxMagnitude_ReturnsNegative1OnEmpty()
+        public void Exp_ThrowsForTooShortDestination()
         {
-            Assert.Equal(-1, TensorPrimitives.IndexOfMaxMagnitude(ReadOnlySpan<float>.Empty));
-        }
+            if (!IsFloatingPoint) return;

-        [Theory]
-        [MemberData(nameof(TensorLengths))]
-        public static void IndexOfMaxMagnitude(int tensorLength)
-        {
-            foreach (int expected in new[] { 0, tensorLength / 2, tensorLength - 1 })
+            Assert.All(Helpers.TensorLengths, tensorLength =>
             {
-                using BoundedMemory<float> x = CreateAndFillTensor(tensorLength);
-                x[expected] = Enumerable.Max(MemoryMarshal.ToEnumerable(x.Memory), Math.Abs) + 1;
-                Assert.Equal(expected, TensorPrimitives.IndexOfMaxMagnitude(x));
-            }
-        }
+                using BoundedMemory<T> x = CreateAndFillTensor(tensorLength);
+                using BoundedMemory<T> destination = CreateTensor(tensorLength - 1);

-        [Theory]
-        [MemberData(nameof(TensorLengths))]
-        public static void IndexOfMaxMagnitude_FirstNaNReturned(int tensorLength)
-        {
-            foreach (int expected in new[] { 0, tensorLength / 2, tensorLength - 1 })
-            {
-                using BoundedMemory<float> x = CreateAndFillTensor(tensorLength);
-                x[expected] = float.NaN;
-                x[tensorLength - 1] = float.NaN;
-                Assert.Equal(expected, TensorPrimitives.IndexOfMaxMagnitude(x));
-            }
+                AssertExtensions.Throws<ArgumentException>("destination", () => Exp(x, destination));
+            });
         }

         [Fact]
-        public static void IndexOfMaxMagnitude_Negative0LesserThanPositive0()
+        public void Exp_ThrowsForOverlapppingInputsWithOutputs()
         {
-            Assert.Equal(0, TensorPrimitives.IndexOfMaxMagnitude([-0f, -0f, -0f, -0f]));
-            Assert.Equal(1, TensorPrimitives.IndexOfMaxMagnitude([-0f, +0f]));
-            Assert.Equal(1, TensorPrimitives.IndexOfMaxMagnitude([-0f, +0f, +0f, +0f]));
-            Assert.Equal(0, TensorPrimitives.IndexOfMaxMagnitude([+0f, -0f]));
-            Assert.Equal(0, TensorPrimitives.IndexOfMaxMagnitude([-1, -0f]));
-            Assert.Equal(2, TensorPrimitives.IndexOfMaxMagnitude([-1, -0f, 1]));
+            if (!IsFloatingPoint) return;
+
+            T[] array = new T[10];
+            AssertExtensions.Throws<ArgumentException>("destination", () => Exp(array.AsSpan(1, 2), array.AsSpan(0, 2)));
+            AssertExtensions.Throws<ArgumentException>("destination", () => Exp(array.AsSpan(1, 2), array.AsSpan(2, 2)));
         }
         #endregion
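`RunForEachSpecialValue` is not defined in this diff. Judging by its call sites (it takes an assertion callback plus the tensor), it appears to plant floating-point edge cases into the tensor and re-run the assertions for each one. A minimal float-only approximation of that idea, purely hypothetical:

```csharp
using System;

static class SpecialValuesSketch
{
    // Plant each special value into the middle of the tensor, leaving the rest
    // of the random data intact, and re-run the caller's assertions.
    public static void RunForEachSpecialValue(Action assert, Span<float> x)
    {
        float[] specials = [float.NaN, float.PositiveInfinity, float.NegativeInfinity,
                            -0f, +0f, float.Epsilon, float.MaxValue, float.MinValue];
        foreach (float special in specials)
        {
            x[x.Length / 2] = special;
            assert();
        }
    }
}
```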

-        #region IndexOfMin
+        #region Log
         [Fact]
-        public static void IndexOfMin_ReturnsNegative1OnEmpty()
+        public void Log_AllValues()
         {
-            Assert.Equal(-1, TensorPrimitives.IndexOfMin(ReadOnlySpan<float>.Empty));
-        }
+            if (!IsFloatingPoint) return;

-        [Theory]
-        [MemberData(nameof(TensorLengths))]
-        public static void IndexOfMin(int tensorLength)
-        {
-            foreach (int expected in new[] { 0, tensorLength / 2, tensorLength - 1 })
+            Assert.All(Helpers.TensorLengthsIncluding0, tensorLength =>
             {
-                using BoundedMemory<float> x = CreateAndFillTensor(tensorLength);
-                x[expected] = Enumerable.Min(MemoryMarshal.ToEnumerable(x.Memory)) - 1;
-                Assert.Equal(expected, TensorPrimitives.IndexOfMin(x));
-            }
-        }
+                using BoundedMemory<T> x = CreateAndFillTensor(tensorLength);
+                using BoundedMemory<T> destination = CreateTensor(tensorLength);

-        [Theory]
-        [MemberData(nameof(TensorLengths))]
-        public static void IndexOfMin_FirstNaNReturned(int tensorLength)
-        {
-            foreach (int expected in new[] { 0, tensorLength / 2, tensorLength - 1 })
-            {
-                using BoundedMemory<float> x = CreateAndFillTensor(tensorLength);
-                x[expected] = float.NaN;
-                x[tensorLength - 1] = float.NaN;
-                Assert.Equal(expected, TensorPrimitives.IndexOfMin(x));
-            }
-        }
+                Log(x, destination);

-        [Fact]
-        public static void IndexOfMin_Negative0LesserThanPositive0()
-        {
-            Assert.Equal(0, TensorPrimitives.IndexOfMin([-0f, +0f]));
-            Assert.Equal(1, TensorPrimitives.IndexOfMin([+0f, -0f]));
-            Assert.Equal(1, TensorPrimitives.IndexOfMin([+0f, -0f, -0f, -0f, -0f]));
-            Assert.Equal(0, TensorPrimitives.IndexOfMin([-1, -0f]));
-            Assert.Equal(0, TensorPrimitives.IndexOfMin([-1, -0f, 1]));
+            for (int i = 0; i < tensorLength; i++)
+            {
+                AssertEqualTolerance(Log(x[i]), destination[i]);
+            }
+            });
         }
-        #endregion

-        #region IndexOfMinMagnitude
         [Fact]
-        public static void IndexOfMinMagnitude_ReturnsNegative1OnEmpty()
+        public void Log_InPlace()
         {
-            Assert.Equal(-1, TensorPrimitives.IndexOfMinMagnitude(ReadOnlySpan<float>.Empty));
-        }
+            if (!IsFloatingPoint) return;

-        [Theory]
-        [MemberData(nameof(TensorLengths))]
-        public static void IndexOfMinMagnitude(int tensorLength)
-        {
-            foreach (int expected in new[] { 0, tensorLength / 2, tensorLength - 1 })
+            Assert.All(Helpers.TensorLengthsIncluding0, tensorLength =>
             {
-                using BoundedMemory<float> x = CreateTensor(tensorLength);
-                for (int i = 0; i < x.Length; i++)
-                {
-                    x[i] = i % 2 == 0 ? 42 : -42;
-                }
+                using BoundedMemory<T> x = CreateAndFillTensor(tensorLength);
+                T[] xOrig = x.Span.ToArray();

-                x[expected] = -41;
+                Log(x, x);

-                Assert.Equal(expected, TensorPrimitives.IndexOfMinMagnitude(x));
-            }
+                for (int i = 0; i < tensorLength; i++)
+                {
+                    AssertEqualTolerance(Log(xOrig[i]), x[i]);
+                }
+            });
         }

-        [Theory]
-        [MemberData(nameof(TensorLengths))]
-        public static void IndexOfMinMagnitude_FirstNaNReturned(int tensorLength)
+        [Fact]
+        public void Log_SpecialValues()
         {
-            foreach (int expected in new[] { 0, tensorLength / 2, tensorLength - 1 })
+            if (!IsFloatingPoint) return;
+
+            Assert.All(Helpers.TensorLengths, tensorLength =>
             {
-                using BoundedMemory<float> x = CreateAndFillTensor(tensorLength);
-                x[expected] = float.NaN;
-                x[tensorLength - 1] = float.NaN;
-                Assert.Equal(expected, TensorPrimitives.IndexOfMinMagnitude(x));
-            }
-        }
+                using BoundedMemory<T> x = CreateAndFillTensor(tensorLength);
+                using BoundedMemory<T> destination = CreateTensor(tensorLength);

-        [Fact]
-        public static void IndexOfMinMagnitude_Negative0LesserThanPositive0()
-        {
-            Assert.Equal(0, TensorPrimitives.IndexOfMinMagnitude([-0f, -0f, -0f, -0f]));
-            Assert.Equal(0, TensorPrimitives.IndexOfMinMagnitude([-0f, +0f]));
-            Assert.Equal(1, TensorPrimitives.IndexOfMinMagnitude([+0f, -0f]));
-            Assert.Equal(1, TensorPrimitives.IndexOfMinMagnitude([+0f, -0f, -0f, -0f]));
-            Assert.Equal(1, TensorPrimitives.IndexOfMinMagnitude([-1, -0f]));
-            Assert.Equal(1, TensorPrimitives.IndexOfMinMagnitude([-1, -0f, 1]));
+                RunForEachSpecialValue(() =>
+                {
+                    Log(x, destination);
+                    for (int i = 0; i < tensorLength; i++)
+                    {
+                        AssertEqualTolerance(Log(x[i]), destination[i]);
+                    }
+                }, x);
+            });
         }
-        #endregion

-        #region Log
-        [Theory]
-        [MemberData(nameof(TensorLengthsIncluding0))]
-        public static void Log(int tensorLength)
+        [Fact]
+        public void Log_ThrowsForTooShortDestination()
         {
-            using BoundedMemory<float> x = CreateAndFillTensor(tensorLength);
-            using BoundedMemory<float> destination = CreateTensor(tensorLength);
+            if (!IsFloatingPoint) return;

-            TensorPrimitives.Log(x, destination);
-
-            for (int i = 0; i < tensorLength; i++)
+            Assert.All(Helpers.TensorLengths, tensorLength =>
             {
-                AssertEqualTolerance(MathF.Log(x[i]), destination[i]);
-            }
+                using BoundedMemory<T> x = CreateAndFillTensor(tensorLength);
+                using BoundedMemory<T> destination = CreateTensor(tensorLength - 1);
+
+                AssertExtensions.Throws<ArgumentException>("destination", () => Log(x, destination));
+            });
         }

-        [Theory]
-        [MemberData(nameof(TensorLengthsIncluding0))]
-        public static void Log_InPlace(int tensorLength)
+        [Fact]
+        public void Log_ThrowsForOverlapppingInputsWithOutputs()
         {
-            using BoundedMemory<float> x = CreateAndFillTensor(tensorLength);
-            float[] xOrig = x.Span.ToArray();
-
-            TensorPrimitives.Log(x, x);
+            if (!IsFloatingPoint) return;

-            for (int i = 0; i < tensorLength; i++)
-            {
-                AssertEqualTolerance(MathF.Log(xOrig[i]), x[i]);
-            }
+            T[] array = new T[10];
+            AssertExtensions.Throws<ArgumentException>("destination", () => Log(array.AsSpan(1, 2), array.AsSpan(0, 2)));
+            AssertExtensions.Throws<ArgumentException>("destination", () => Log(array.AsSpan(1, 2), array.AsSpan(2, 2)));
         }
+        #endregion
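The `if (!IsFloatingPoint) return;` guards that open these tests are what lets a single generic test body be instantiated for integer types while skipping float-only operations like Log and Exp. The definition of the flag is not part of this diff; one plausible shape (hypothetical, the real suite may expose it as an abstract member overridden per instantiation):

```csharp
// Inside the generic test class TensorPrimitivesTests<T> (name assumed):
private static bool IsFloatingPoint =>
    typeof(T) == typeof(float) ||
    typeof(T) == typeof(double) ||
    typeof(T) == typeof(Half);
```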

-        [Theory]
-        [MemberData(nameof(TensorLengths))]
-        public static void Log_SpecialValues(int tensorLength)
+        #region Log2
+        [Fact]
+        public void Log2_AllValues()
         {
-            using BoundedMemory<float> x = CreateAndFillTensor(tensorLength);
-            using BoundedMemory<float> destination = CreateTensor(tensorLength);
+            if (!IsFloatingPoint) return;

-            RunForEachSpecialValue(() =>
+            Assert.All(Helpers.TensorLengthsIncluding0, tensorLength =>
             {
-                TensorPrimitives.Log(x, destination);
+                using BoundedMemory<T> x = CreateAndFillTensor(tensorLength);
+                using BoundedMemory<T> destination = CreateTensor(tensorLength);
+
+                Log2(x, destination);
+
                 for (int i = 0; i < tensorLength; i++)
                 {
-                    AssertEqualTolerance(MathF.Log(x[i]), destination[i]);
+                    AssertEqualTolerance(Log2(x[i]), destination[i]);
                 }
-            }, x);
-        }
-
-        [Theory]
-        [MemberData(nameof(TensorLengths))]
-        public static void Log_ThrowsForTooShortDestination(int tensorLength)
-        {
-            using BoundedMemory<float> x = CreateAndFillTensor(tensorLength);
-            using BoundedMemory<float> destination = CreateTensor(tensorLength - 1);
-
-            AssertExtensions.Throws<ArgumentException>("destination", () => TensorPrimitives.Log(x, destination));
+            });
         }

         [Fact]
-        public static void Log_ThrowsForOverlapppingInputsWithOutputs()
-        {
-            float[] array = new float[10];
-            AssertExtensions.Throws<ArgumentException>("destination", () => TensorPrimitives.Log(array.AsSpan(1, 2), array.AsSpan(0, 2)));
-            AssertExtensions.Throws<ArgumentException>("destination", () => TensorPrimitives.Log(array.AsSpan(1, 2), array.AsSpan(2, 2)));
-        }
-        #endregion
-
-        #region Log2
-        [Theory]
-        [MemberData(nameof(TensorLengthsIncluding0))]
-        public static void Log2(int tensorLength)
+        public void Log2_InPlace()
         {
-            using BoundedMemory<float> x = CreateAndFillTensor(tensorLength);
-            using BoundedMemory<float> destination = CreateTensor(tensorLength);
-
-            TensorPrimitives.Log2(x, destination);
+            if (!IsFloatingPoint) return;

-            for (int i = 0; i < tensorLength; i++)
+            Assert.All(Helpers.TensorLengthsIncluding0, tensorLength =>
             {
-                AssertEqualTolerance(MathF.Log(x[i], 2), destination[i]);
-            }
-        }
-
-        [Theory]
-        [MemberData(nameof(TensorLengthsIncluding0))]
-        public static void Log2_InPlace(int tensorLength)
-        {
-            using BoundedMemory<float> x = CreateAndFillTensor(tensorLength);
-            float[] xOrig = x.Span.ToArray();
+                using BoundedMemory<T> x = CreateAndFillTensor(tensorLength);
+                T[] xOrig = x.Span.ToArray();

-            TensorPrimitives.Log2(x, x);
+                Log2(x, x);

-            for (int i = 0; i < tensorLength; i++)
-            {
-                AssertEqualTolerance(MathF.Log(xOrig[i], 2), x[i]);
-            }
+                for (int i = 0; i < tensorLength; i++)
+                {
+                    AssertEqualTolerance(Log2(xOrig[i]), x[i]);
+                }
+            });
         }

-        [Theory]
-        [MemberData(nameof(TensorLengths))]
-        public static void Log2_SpecialValues(int tensorLength)
+        [Fact]
+        public void Log2_SpecialValues()
         {
-            using BoundedMemory<float> x = CreateAndFillTensor(tensorLength);
-            using BoundedMemory<float> destination = CreateTensor(tensorLength);
+            if (!IsFloatingPoint) return;

-            RunForEachSpecialValue(() =>
+            Assert.All(Helpers.TensorLengths, tensorLength =>
             {
-                TensorPrimitives.Log2(x, destination);
-                for (int i = 0; i < tensorLength; i++)
+                using BoundedMemory<T> x = CreateAndFillTensor(tensorLength);
+                using BoundedMemory<T> destination = CreateTensor(tensorLength);
+
+                RunForEachSpecialValue(() =>
                 {
-                    AssertEqualTolerance(MathF.Log(x[i], 2), destination[i]);
-                }
-            }, x);
+                    Log2(x, destination);
+                    for (int i = 0; i < tensorLength; i++)
+                    {
+                        AssertEqualTolerance(Log2(x[i]), destination[i]);
+                    }
+                }, x);
+            });
         }

-        [Theory]
-        [MemberData(nameof(TensorLengths))]
-        public static void Log2_ThrowsForTooShortDestination(int tensorLength)
+        [Fact]
+        public void Log2_ThrowsForTooShortDestination()
         {
-            using BoundedMemory<float> x = CreateAndFillTensor(tensorLength);
-            using BoundedMemory<float> destination = CreateTensor(tensorLength - 1);
+            if (!IsFloatingPoint) return;
+
+            Assert.All(Helpers.TensorLengths, tensorLength =>
+            {
+                using BoundedMemory<T> x = CreateAndFillTensor(tensorLength);
+                using BoundedMemory<T> destination = CreateTensor(tensorLength - 1);

-            AssertExtensions.Throws<ArgumentException>("destination", () => TensorPrimitives.Log2(x, destination));
+                AssertExtensions.Throws<ArgumentException>("destination", () => Log2(x, destination));
+            });
         }

         [Fact]
-        public static void Log2_ThrowsForOverlapppingInputsWithOutputs()
+        public void Log2_ThrowsForOverlapppingInputsWithOutputs()
         {
-            float[] array = new float[10];
-            AssertExtensions.Throws<ArgumentException>("destination", () => TensorPrimitives.Log2(array.AsSpan(1, 2), array.AsSpan(0, 2)));
-            AssertExtensions.Throws<ArgumentException>("destination", () => TensorPrimitives.Log2(array.AsSpan(1, 2), array.AsSpan(2, 2)));
+            if (!IsFloatingPoint) return;
+
+            T[] array = new T[10];
+            AssertExtensions.Throws<ArgumentException>("destination", () => Log2(array.AsSpan(1, 2), array.AsSpan(0, 2)));
+            AssertExtensions.Throws<ArgumentException>("destination", () => Log2(array.AsSpan(1, 2), array.AsSpan(2, 2)));
         }
         #endregion

         #region Max
         [Fact]
-        public static void Max_Tensor_ThrowsForEmpty()
+        public void Max_Tensor_ThrowsForEmpty()
         {
-            Assert.Throws<ArgumentException>(() => TensorPrimitives.Max(ReadOnlySpan<float>.Empty));
+            Assert.Throws<ArgumentException>(() => Max(ReadOnlySpan<T>.Empty));
         }

-        [Theory]
-        [MemberData(nameof(TensorLengths))]
-        public static void Max_Tensor(int tensorLength)
+        [Fact]
+        public void Max_Tensor()
         {
-            using BoundedMemory<float> x = CreateAndFillTensor(tensorLength);
+            Assert.All(Helpers.TensorLengths, tensorLength =>
+            {
+                using BoundedMemory<T> x = CreateAndFillTensor(tensorLength);

-            Assert.Equal(Enumerable.Max(MemoryMarshal.ToEnumerable(x.Memory)), TensorPrimitives.Max(x));
+                Assert.Equal(Enumerable.Max(MemoryMarshal.ToEnumerable(x.Memory)), Max(x));

-            float max = float.NegativeInfinity;
-            foreach (float f in x.Span)
-            {
-                max = Math.Max(max, f);
-            }
+                T max = x.Span[0];
+                foreach (T f in x.Span)
+                {
+                    max = Max(max, f);
+                }
+
+                Assert.Equal(max, Max(x));

-            Assert.Equal(max, TensorPrimitives.Max(x));
-            Assert.Equal(SingleToUInt32(x[TensorPrimitives.IndexOfMax(x)]), SingleToUInt32(TensorPrimitives.Max(x)));
+                // TODO: Put a variant of this back once we have IndexOf routines
+                // Assert.Equal(SingleToUInt32(x[IndexOfMax(x)]), SingleToUInt32(Max(x)));
+            });
         }

-        [Theory]
-        [MemberData(nameof(TensorLengths))]
-        public static void Max_Tensor_SpecialValues(int tensorLength)
+        [Fact]
+        public void Max_Tensor_SpecialValues()
         {
-            using BoundedMemory<float> x = CreateAndFillTensor(tensorLength);
-
-            RunForEachSpecialValue(() =>
+            Assert.All(Helpers.TensorLengths, tensorLength =>
             {
-                float max = float.NegativeInfinity;
-                foreach (float f in x.Span)
+                using BoundedMemory<T> x = CreateAndFillTensor(tensorLength);
+
+                RunForEachSpecialValue(() =>
                 {
-                    max = Math.Max(max, f);
-                }
+                    T max = x.Span[0];
+                    foreach (T f in x.Span)
+                    {
+                        max = Max(max, f);
+                    }

-                Assert.Equal(max, TensorPrimitives.Max(x));
-                Assert.Equal(SingleToUInt32(x[TensorPrimitives.IndexOfMax(x)]), SingleToUInt32(TensorPrimitives.Max(x)));
-            }, x);
+                    Assert.Equal(max, Max(x));
+
+                    // TODO: Put a variant of this back once we have IndexOf routines
+                    // Assert.Equal(SingleToUInt32(x[IndexOfMax(x)]), SingleToUInt32(Max(x)));
+                }, x);
+            });
         }

-        [Theory]
-        [MemberData(nameof(TensorLengths))]
-        public static void Max_Tensor_NanReturned(int tensorLength)
+        [Fact]
+        public void Max_Tensor_NanReturned()
         {
-            using BoundedMemory<float> x = CreateTensor(tensorLength);
-            foreach (int expected in new[] { 0, tensorLength / 2, tensorLength - 1 })
+            if (!IsFloatingPoint) return;
+
+            Assert.All(Helpers.TensorLengths, tensorLength =>
             {
-                FillTensor(x);
-                x[expected] = float.NaN;
-                Assert.Equal(float.NaN, TensorPrimitives.Max(x));
-            }
+                using BoundedMemory<T> x = CreateTensor(tensorLength);
+                foreach (int expected in new[] { 0, tensorLength / 2, tensorLength - 1 })
+                {
+                    FillTensor(x);
+                    x[expected] = NaN;
+                    Assert.Equal(NaN, Max(x));
+                }
+            });
         }

         [Fact]
-        public static void Max_Tensor_Negative0LesserThanPositive0()
+        public void Max_Tensor_Negative0LesserThanPositive0()
         {
-            Assert.Equal(+0f, TensorPrimitives.Max([-0f, +0f]));
-            Assert.Equal(+0f, TensorPrimitives.Max([+0f, -0f]));
-            Assert.Equal(-0f, TensorPrimitives.Max([-1, -0f]));
-            Assert.Equal(1, TensorPrimitives.Max([-1, -0f, 1]));
+            Assert.Equal(ConvertFromSingle(+0f), Max([ConvertFromSingle(-0f), ConvertFromSingle(+0f)]));
+            Assert.Equal(ConvertFromSingle(+0f), Max([ConvertFromSingle(+0f), ConvertFromSingle(-0f)]));
+            Assert.Equal(ConvertFromSingle(-0f), Max([ConvertFromSingle(-1), ConvertFromSingle(-0f)]));
+            Assert.Equal(ConvertFromSingle(1), Max([ConvertFromSingle(-1), ConvertFromSingle(-0f), ConvertFromSingle(1)]));
         }

-        [Theory]
-        [MemberData(nameof(TensorLengthsIncluding0))]
-        public static void Max_TwoTensors(int tensorLength)
+        [Fact]
+        public void Max_TwoTensors()
         {
-            using BoundedMemory<float> x = CreateAndFillTensor(tensorLength);
-            using BoundedMemory<float> y = CreateAndFillTensor(tensorLength);
-            using BoundedMemory<float> destination = CreateTensor(tensorLength);
+            Assert.All(Helpers.TensorLengthsIncluding0, tensorLength =>
+            {
+                using BoundedMemory<T> x = CreateAndFillTensor(tensorLength);
+                using BoundedMemory<T> y = CreateAndFillTensor(tensorLength);
+                using BoundedMemory<T> destination = CreateTensor(tensorLength);

-            TensorPrimitives.Max(x, y, destination);
+                Max(x, y, destination);

-            for (int i = 0; i < tensorLength; i++)
-            {
-                AssertEqualTolerance(MathF.Max(x[i], y[i]), destination[i]);
-            }
+                for (int i = 0; i < tensorLength; i++)
+                {
+                    AssertEqualTolerance(Max(x[i], y[i]), destination[i]);
+                }
+            });
         }

-        [Theory]
-        [MemberData(nameof(TensorLengthsIncluding0))]
-        public static void Max_TwoTensors_InPlace(int tensorLength)
+        [Fact]
+        public void Max_TwoTensors_InPlace()
         {
-            using BoundedMemory<float> x = CreateAndFillTensor(tensorLength);
-            using BoundedMemory<float> y = CreateAndFillTensor(tensorLength);
-            float[] xOrig = x.Span.ToArray(), yOrig = y.Span.ToArray();
+            Assert.All(Helpers.TensorLengthsIncluding0, tensorLength =>
+            {
+                using BoundedMemory<T> x = CreateAndFillTensor(tensorLength);
+                using BoundedMemory<T> y = CreateAndFillTensor(tensorLength);
+                T[] xOrig = x.Span.ToArray(), yOrig = y.Span.ToArray();

-            TensorPrimitives.Max(x, y, x);
+                Max(x, y, x);

-            for (int i = 0; i < tensorLength; i++)
-            {
-                AssertEqualTolerance(MathF.Max(xOrig[i], y[i]), x[i]);
-            }
+                for (int i = 0; i < tensorLength; i++)
+                {
+                    AssertEqualTolerance(Max(xOrig[i], y[i]), x[i]);
+                }

-            xOrig.AsSpan().CopyTo(x.Span);
-            yOrig.AsSpan().CopyTo(y.Span);
+                xOrig.AsSpan().CopyTo(x.Span);
+                yOrig.AsSpan().CopyTo(y.Span);

-            TensorPrimitives.Max(x, y, y);
+                Max(x, y, y);

-            for (int i = 0; i < tensorLength; i++)
-            {
-                AssertEqualTolerance(MathF.Max(x[i], yOrig[i]), y[i]);
-            }
+                for (int i = 0; i < tensorLength; i++)
+                {
+                    AssertEqualTolerance(Max(x[i], yOrig[i]), y[i]);
+                }
+            });
         }

-        [Theory]
-        [MemberData(nameof(TensorLengths))]
-        public static void Max_TwoTensors_SpecialValues(int tensorLength)
+        [Fact]
+        public void Max_TwoTensors_SpecialValues()
         {
-            using BoundedMemory<float> x = CreateAndFillTensor(tensorLength);
-            using BoundedMemory<float> y = CreateAndFillTensor(tensorLength);
-            using BoundedMemory<float> destination = CreateTensor(tensorLength);
+            Assert.All(Helpers.TensorLengths, tensorLength =>
+            {
+                using BoundedMemory<T> x = CreateAndFillTensor(tensorLength);
+                using BoundedMemory<T> y = CreateAndFillTensor(tensorLength);
+                using BoundedMemory<T> destination = CreateTensor(tensorLength);

-            SetSpecialValues(x, y);
+                SetSpecialValues(x, y);

-            TensorPrimitives.Max(x, y, destination);
-            for (int i = 0; i < tensorLength; i++)
-            {
-                AssertEqualTolerance(MathF.Max(x[i], y[i]), destination[i]);
-            }
+                Max(x, y, destination);
+                for (int i = 0; i < tensorLength; i++)
+                {
+                    AssertEqualTolerance(Max(x[i], y[i]), destination[i]);
+                }

-            TensorPrimitives.Max(y, x, destination);
-            for (int i = 0; i < tensorLength; i++)
-            {
-                AssertEqualTolerance(MathF.Max(y[i], x[i]), destination[i]);
-            }
+                Max(y, x, destination);
+                for (int i = 0; i < tensorLength; i++)
+                {
+                    AssertEqualTolerance(Max(y[i], x[i]), destination[i]);
+                }
+            });
         }

-        [Theory]
-        [MemberData(nameof(TensorLengths))]
-        public static void Max_TwoTensors_ThrowsForMismatchedLengths(int tensorLength)
+        [Fact]
+        public void Max_TwoTensors_ThrowsForMismatchedLengths()
         {
-            using BoundedMemory<float> x = CreateAndFillTensor(tensorLength);
-            using BoundedMemory<float> y = CreateAndFillTensor(tensorLength - 1);
-            using BoundedMemory<float> destination = CreateTensor(tensorLength);
+            Assert.All(Helpers.TensorLengths, tensorLength =>
+            {
+                using BoundedMemory<T> x = CreateAndFillTensor(tensorLength);
+                using BoundedMemory<T> y = CreateAndFillTensor(tensorLength - 1);
+                using BoundedMemory<T> destination = CreateTensor(tensorLength);

-            Assert.Throws<ArgumentException>(() => TensorPrimitives.Max(x, y, destination));
-            Assert.Throws<ArgumentException>(() => TensorPrimitives.Max(y, x, destination));
+                Assert.Throws<ArgumentException>(() => Max(x, y, destination));
+                Assert.Throws<ArgumentException>(() => Max(y, x, destination));
+            });
         }

-        [Theory]
-        [MemberData(nameof(TensorLengths))]
-        public static void Max_TwoTensors_ThrowsForTooShortDestination(int tensorLength)
+        [Fact]
+        public void Max_TwoTensors_ThrowsForTooShortDestination()
         {
-            using BoundedMemory<float> x = CreateAndFillTensor(tensorLength);
-            using BoundedMemory<float> y = CreateAndFillTensor(tensorLength);
-            using BoundedMemory<float> destination = CreateTensor(tensorLength - 1);
+            Assert.All(Helpers.TensorLengths, tensorLength =>
+            {
+                using BoundedMemory<T> x = CreateAndFillTensor(tensorLength);
+                using BoundedMemory<T> y = CreateAndFillTensor(tensorLength);
+                using BoundedMemory<T> destination = CreateTensor(tensorLength - 1);

-            AssertExtensions.Throws<ArgumentException>("destination", () => TensorPrimitives.Max(x, y, destination));
+                AssertExtensions.Throws<ArgumentException>("destination", () => Max(x, y, destination));
+            });
         }

         [Fact]
-        public static void Max_TwoTensors_ThrowsForOverlapppingInputsWithOutputs()
+        public void Max_TwoTensors_ThrowsForOverlapppingInputsWithOutputs()
         {
-            float[] array = new float[10];
-            AssertExtensions.Throws<ArgumentException>("destination", () => TensorPrimitives.Max(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(0, 2)));
-            AssertExtensions.Throws<ArgumentException>("destination", () => TensorPrimitives.Max(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(2, 2)));
-            AssertExtensions.Throws<ArgumentException>("destination", () => TensorPrimitives.Max(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(3, 2)));
-            AssertExtensions.Throws<ArgumentException>("destination", () => TensorPrimitives.Max(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(5, 2)));
+            T[] array = new T[10];
+            AssertExtensions.Throws<ArgumentException>("destination", () => Max(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(0, 2)));
+            AssertExtensions.Throws<ArgumentException>("destination", () => Max(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(2, 2)));
+            AssertExtensions.Throws<ArgumentException>("destination", () => Max(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(3, 2)));
+            AssertExtensions.Throws<ArgumentException>("destination", () => Max(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(5, 2)));
         }
         #endregion
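The Max tests above lock in two IEEE 754 "maximum" properties that the float overloads already document and test: any NaN input propagates to the result, and +0 compares greater than -0 (bit-exactly). A standalone illustration against the existing float surface:

```csharp
using System;
using System.Diagnostics;
using System.Numerics.Tensors;

static class MaxSemanticsDemo
{
    public static void Run()
    {
        // NaN anywhere in the input propagates to the result.
        ReadOnlySpan<float> withNaN = [1f, float.NaN, 3f];
        Debug.Assert(float.IsNaN(TensorPrimitives.Max(withNaN)));

        // +0 is treated as greater than -0, down to the bit pattern.
        ReadOnlySpan<float> zeros = [-0f, +0f];
        Debug.Assert(BitConverter.SingleToUInt32Bits(TensorPrimitives.Max(zeros)) ==
                     BitConverter.SingleToUInt32Bits(+0f));
    }
}
```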

         #region MaxMagnitude
         [Fact]
-        public static void MaxMagnitude_Tensor_ThrowsForEmpty()
+        public void MaxMagnitude_Tensor_ThrowsForEmpty()
         {
-            Assert.Throws<ArgumentException>(() => TensorPrimitives.MaxMagnitude(ReadOnlySpan<float>.Empty));
+            Assert.Throws<ArgumentException>(() => MaxMagnitude(ReadOnlySpan<T>.Empty));
         }

-        [Theory]
-        [MemberData(nameof(TensorLengths))]
-        public static void MaxMagnitude_Tensor(int tensorLength)
+        [ActiveIssue("https://github.com/dotnet/runtime/issues/96443", TestRuntimes.Mono)]
+        [Fact]
+        public void MaxMagnitude_Tensor()
         {
-            using BoundedMemory<float> x = CreateAndFillTensor(tensorLength);
-
-            float maxMagnitude = x[0];
-            foreach (float f in x.Span)
+            Assert.All(Helpers.TensorLengths, tensorLength =>
             {
-                maxMagnitude = MathFMaxMagnitude(maxMagnitude, f);
-            }
+                using BoundedMemory<T> x = CreateAndFillTensor(tensorLength);
+
+                T maxMagnitude = x[0];
+                foreach (T f in x.Span)
+                {
+                    maxMagnitude = MaxMagnitude(maxMagnitude, f);
+                }

-            Assert.Equal(maxMagnitude, TensorPrimitives.MaxMagnitude(x));
-            Assert.Equal(SingleToUInt32(x[TensorPrimitives.IndexOfMaxMagnitude(x)]), SingleToUInt32(TensorPrimitives.MaxMagnitude(x)));
+                Assert.Equal(maxMagnitude, MaxMagnitude(x));
+
+                // TODO: Put a variant of this back once we have IndexOf routines
+                // Assert.Equal(SingleToUInt32(x[IndexOfMaxMagnitude(x)]), SingleToUInt32(MaxMagnitude(x)));
+            });
         }

-        [Theory]
-        [MemberData(nameof(TensorLengths))]
-        public static void MaxMagnitude_Tensor_SpecialValues(int tensorLength)
+        [Fact]
+        public void MaxMagnitude_Tensor_SpecialValues()
         {
-            using BoundedMemory<float> x = CreateAndFillTensor(tensorLength);
-
-            RunForEachSpecialValue(() =>
+            Assert.All(Helpers.TensorLengths, tensorLength =>
             {
-                float maxMagnitude = x[0];
-                foreach (float f in x.Span)
+                using BoundedMemory<T> x = CreateAndFillTensor(tensorLength);
+
+                RunForEachSpecialValue(() =>
                 {
-                    maxMagnitude = MathFMaxMagnitude(maxMagnitude, f);
-                }
+                    T maxMagnitude = x[0];
+                    foreach (T f in x.Span)
+                    {
+                        maxMagnitude = MaxMagnitude(maxMagnitude, f);
+                    }
+
+                    Assert.Equal(maxMagnitude, MaxMagnitude(x));

-                Assert.Equal(maxMagnitude, TensorPrimitives.MaxMagnitude(x));
-                Assert.Equal(SingleToUInt32(x[TensorPrimitives.IndexOfMaxMagnitude(x)]), SingleToUInt32(TensorPrimitives.MaxMagnitude(x)));
-            }, x);
+                    // TODO: Put a variant of this back once we have IndexOf routines
+                    // Assert.Equal(SingleToUInt32(x[IndexOfMaxMagnitude(x)]), SingleToUInt32(MaxMagnitude(x)));
+                }, x);
+            });
         }

-        [Theory]
-        [MemberData(nameof(TensorLengths))]
-        public static void MaxMagnitude_Tensor_NanReturned(int tensorLength)
+        [Fact]
+        public void MaxMagnitude_Tensor_NanReturned()
         {
-            using BoundedMemory<float> x = CreateTensor(tensorLength);
-            foreach (int expected in new[] { 0, tensorLength / 2, tensorLength - 1 })
+            if (!IsFloatingPoint) return;
+
+            Assert.All(Helpers.TensorLengths, tensorLength =>
             {
-                FillTensor(x);
-                x[expected] = float.NaN;
-                Assert.Equal(float.NaN, TensorPrimitives.MaxMagnitude(x));
-            }
+                using BoundedMemory<T> x = CreateTensor(tensorLength);
+                foreach (int expected in new[] { 0, tensorLength / 2, tensorLength - 1 })
+                {
+                    FillTensor(x);
+                    x[expected] = ConvertFromSingle(float.NaN);
+                    Assert.Equal(ConvertFromSingle(float.NaN), MaxMagnitude(x));
+                }
+            });
         }

         [Fact]
-        public static void MaxMagnitude_Tensor_Negative0LesserThanPositive0()
+        public void MaxMagnitude_Tensor_Negative0LesserThanPositive0()
         {
-            Assert.Equal(+0f, TensorPrimitives.MaxMagnitude([-0f, +0f]));
-            Assert.Equal(+0f, TensorPrimitives.MaxMagnitude([+0f, -0f]));
-            Assert.Equal(-1, TensorPrimitives.MaxMagnitude([-1, -0f]));
-            Assert.Equal(1, TensorPrimitives.MaxMagnitude([-1, -0f, 1]));
-            Assert.Equal(0f, TensorPrimitives.MaxMagnitude([-0f, -0f, -0f, -0f, -0f, 0f]));
-            Assert.Equal(1, TensorPrimitives.MaxMagnitude([-0f, -0f, -0f, -0f, -1, -0f, 0f, 1]));
+            Assert.Equal(ConvertFromSingle(0), MaxMagnitude([ConvertFromSingle(-0f), ConvertFromSingle(+0f)]));
+            Assert.Equal(ConvertFromSingle(0), MaxMagnitude([ConvertFromSingle(+0f), ConvertFromSingle(-0f)]));
+            Assert.Equal(ConvertFromSingle(-1), MaxMagnitude([ConvertFromSingle(-1), ConvertFromSingle(-0f)]));
+            Assert.Equal(ConvertFromSingle(1), MaxMagnitude([ConvertFromSingle(-1), ConvertFromSingle(-0f), ConvertFromSingle(1)]));
+            Assert.Equal(ConvertFromSingle(0), MaxMagnitude([ConvertFromSingle(-0f), ConvertFromSingle(-0f), ConvertFromSingle(-0f), ConvertFromSingle(-0f), ConvertFromSingle(-0f), ConvertFromSingle(0f)]));
+            Assert.Equal(ConvertFromSingle(1), MaxMagnitude([ConvertFromSingle(-0f), ConvertFromSingle(-0f), ConvertFromSingle(-0f), ConvertFromSingle(-0f), ConvertFromSingle(-1), ConvertFromSingle(-0f), ConvertFromSingle(0f), ConvertFromSingle(1)]));
         }

-        [Theory]
-        [MemberData(nameof(TensorLengthsIncluding0))]
-        public static void MaxMagnitude_TwoTensors(int tensorLength)
+        [ActiveIssue("https://github.com/dotnet/runtime/issues/96443", TestRuntimes.Mono)]
+        [Fact]
+        public void MaxMagnitude_TwoTensors()
         {
-            using BoundedMemory<float> x = CreateAndFillTensor(tensorLength);
-            using BoundedMemory<float> y = CreateAndFillTensor(tensorLength);
-            using BoundedMemory<float> destination = CreateTensor(tensorLength);
+            Assert.All(Helpers.TensorLengthsIncluding0, tensorLength =>
+            {
+                using BoundedMemory<T> x = CreateAndFillTensor(tensorLength);
+                using BoundedMemory<T> y = CreateAndFillTensor(tensorLength);
+                using BoundedMemory<T> destination = CreateTensor(tensorLength);

-            TensorPrimitives.MaxMagnitude(x, y, destination);
+                MaxMagnitude(x, y, destination);

-            for (int i = 0; i < tensorLength; i++)
-            {
-                AssertEqualTolerance(MathFMaxMagnitude(x[i], y[i]), destination[i]);
-            }
+                for (int i = 0; i < tensorLength; i++)
+                {
+                    AssertEqualTolerance(MaxMagnitude(x[i], y[i]), destination[i]);
+                }
+            });
         }

-        [Theory]
-        [MemberData(nameof(TensorLengthsIncluding0))]
-        public static void MaxMagnitude_TwoTensors_InPlace(int tensorLength)
+        [ActiveIssue("https://github.com/dotnet/runtime/issues/96443", TestRuntimes.Mono)]
+        [Fact]
+        public void MaxMagnitude_TwoTensors_InPlace()
         {
-            using BoundedMemory<float> x = CreateAndFillTensor(tensorLength);
-            using BoundedMemory<float> y = CreateAndFillTensor(tensorLength);
-            float[] xOrig = x.Span.ToArray(), yOrig = y.Span.ToArray();
+            Assert.All(Helpers.TensorLengthsIncluding0, tensorLength =>
+            {
+                using BoundedMemory<T> x = CreateAndFillTensor(tensorLength);
+                using BoundedMemory<T> y = CreateAndFillTensor(tensorLength);
+                T[] xOrig = x.Span.ToArray(), yOrig = y.Span.ToArray();

-            TensorPrimitives.MaxMagnitude(x, y, x);
+                MaxMagnitude(x, y, x);

-            for (int i = 0; i < tensorLength; i++)
-            {
-                AssertEqualTolerance(MathFMaxMagnitude(xOrig[i], y[i]), x[i]);
-            }
+                for (int i = 0; i < tensorLength; i++)
+                {
+                    AssertEqualTolerance(MaxMagnitude(xOrig[i], y[i]), x[i]);
+                }

-            xOrig.AsSpan().CopyTo(x.Span);
-            yOrig.AsSpan().CopyTo(y.Span);
+                xOrig.AsSpan().CopyTo(x.Span);
+                yOrig.AsSpan().CopyTo(y.Span);

-            TensorPrimitives.MaxMagnitude(x, y, y);
+                MaxMagnitude(x, y, y);

-            for (int i = 0; i < tensorLength; i++)
-            {
-                AssertEqualTolerance(MathFMaxMagnitude(x[i], yOrig[i]), y[i]);
-            }
+                for (int i = 0; i < tensorLength; i++)
+                {
+                    AssertEqualTolerance(MaxMagnitude(x[i], yOrig[i]), y[i]);
+                }
+            });
         }

-        [Theory]
-        [MemberData(nameof(TensorLengths))]
-        public static void MaxMagnitude_TwoTensors_SpecialValues(int tensorLength)
+        [ActiveIssue("https://github.com/dotnet/runtime/issues/96443", TestRuntimes.Mono)]
+        [Fact]
+        public void MaxMagnitude_TwoTensors_SpecialValues()
         {
-            using BoundedMemory<float> x = CreateAndFillTensor(tensorLength);
-            using BoundedMemory<float> y = CreateAndFillTensor(tensorLength);
-            using BoundedMemory<float> destination = CreateTensor(tensorLength);
+            Assert.All(Helpers.TensorLengths, tensorLength =>
+            {
+                using BoundedMemory<T> x = CreateAndFillTensor(tensorLength);
+                using BoundedMemory<T> y = CreateAndFillTensor(tensorLength);
+                using BoundedMemory<T> destination = CreateTensor(tensorLength);

-            SetSpecialValues(x, y);
+                SetSpecialValues(x, y);

-            TensorPrimitives.MaxMagnitude(x, y, destination);
-            for (int i = 0; i < tensorLength; i++)
-            {
-                AssertEqualTolerance(MathFMaxMagnitude(x[i], y[i]), destination[i]);
-            }
+                MaxMagnitude(x, y, destination);
+                for (int i = 0; i < tensorLength; i++)
+                {
+                    AssertEqualTolerance(MaxMagnitude(x[i], y[i]), destination[i]);
+                }

-            TensorPrimitives.MaxMagnitude(y, x, destination);
-            for (int i = 0; i < tensorLength; i++)
-            {
-                AssertEqualTolerance(MathFMaxMagnitude(y[i], x[i]), destination[i]);
-            }
+                MaxMagnitude(y, x, destination);
+                for (int i = 0; i < tensorLength; i++)
+                {
+                    AssertEqualTolerance(MaxMagnitude(y[i], x[i]), destination[i]);
+                }
+            });
         }

-        [Theory]
-        [MemberData(nameof(TensorLengths))]
-        public static void MaxMagnitude_TwoTensors_ThrowsForMismatchedLengths(int tensorLength)
+        [Fact]
+        public void MaxMagnitude_TwoTensors_ThrowsForMismatchedLengths()
         {
-            using BoundedMemory<float> x = CreateAndFillTensor(tensorLength);
-            using BoundedMemory<float> y = CreateAndFillTensor(tensorLength - 1);
-            using BoundedMemory<float> destination = CreateTensor(tensorLength);
+            Assert.All(Helpers.TensorLengths, tensorLength =>
+            {
+                using BoundedMemory<T> x = CreateAndFillTensor(tensorLength);
+                using BoundedMemory<T> y = CreateAndFillTensor(tensorLength - 1);
+                using BoundedMemory<T> destination = CreateTensor(tensorLength);

-            Assert.Throws<ArgumentException>(() => TensorPrimitives.MaxMagnitude(x, y, destination));
-            Assert.Throws<ArgumentException>(() => TensorPrimitives.MaxMagnitude(y, x, destination));
+                Assert.Throws<ArgumentException>(() => MaxMagnitude(x, y, destination));
+                Assert.Throws<ArgumentException>(() => MaxMagnitude(y, x, destination));
+            });
         }

-        [Theory]
-        [MemberData(nameof(TensorLengths))]
-        public static void MaxMagnitude_TwoTensors_ThrowsForTooShortDestination(int tensorLength)
+        [Fact]
+        public void MaxMagnitude_TwoTensors_ThrowsForTooShortDestination()
         {
-            using BoundedMemory<float> x = CreateAndFillTensor(tensorLength);
-            using BoundedMemory<float> y = CreateAndFillTensor(tensorLength);
-            using BoundedMemory<float> destination = CreateTensor(tensorLength - 1);
+            Assert.All(Helpers.TensorLengths, tensorLength =>
+            {
+                using BoundedMemory<T> x = CreateAndFillTensor(tensorLength);
+                using BoundedMemory<T> y = CreateAndFillTensor(tensorLength);
+                using BoundedMemory<T> destination = CreateTensor(tensorLength - 1);

-            AssertExtensions.Throws<ArgumentException>("destination", () => TensorPrimitives.MaxMagnitude(x, y, destination));
+                AssertExtensions.Throws<ArgumentException>("destination", () => MaxMagnitude(x, y, destination));
+            });
         }

         [Fact]
-        public static void MaxMagnitude_TwoTensors_ThrowsForOverlapppingInputsWithOutputs()
+        public void MaxMagnitude_TwoTensors_ThrowsForOverlapppingInputsWithOutputs()
         {
-            float[] array = new float[10];
-            AssertExtensions.Throws<ArgumentException>("destination", () => TensorPrimitives.MaxMagnitude(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(0, 2)));
-            AssertExtensions.Throws<ArgumentException>("destination", () => TensorPrimitives.MaxMagnitude(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(2, 2)));
-            AssertExtensions.Throws<ArgumentException>("destination", () => TensorPrimitives.MaxMagnitude(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(3, 2)));
-            AssertExtensions.Throws<ArgumentException>("destination", () => TensorPrimitives.MaxMagnitude(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(5, 2)));
+            T[] array = new T[10];
+            AssertExtensions.Throws<ArgumentException>("destination", () => MaxMagnitude(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(0, 2)));
+            AssertExtensions.Throws<ArgumentException>("destination", () => MaxMagnitude(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(2, 2)));
+            AssertExtensions.Throws<ArgumentException>("destination", () => MaxMagnitude(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(3, 2)));
+            AssertExtensions.Throws<ArgumentException>("destination", () => MaxMagnitude(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(5, 2)));
         }
         #endregion
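The MaxMagnitude tests encode the magnitude-comparison rules the tests above assert: the larger absolute value wins and keeps its sign, ties between -0 and +0 resolve to the positive value, and NaN propagates. `INumberBase<T>.MaxMagnitude` carries exactly these semantics, so a standalone demonstration using `double`'s implementation:

```csharp
using System;
using System.Diagnostics;

static class MaxMagnitudeSemanticsDemo
{
    public static void Run()
    {
        // Larger absolute value wins and keeps its sign.
        Debug.Assert(double.MaxMagnitude(-2.0, 1.0) == -2.0);

        // Equal magnitudes (-0 vs +0) prefer the positive operand.
        Debug.Assert(!double.IsNegative(double.MaxMagnitude(-0.0, +0.0)));

        // NaN propagates.
        Debug.Assert(double.IsNaN(double.MaxMagnitude(double.NaN, 1.0)));
    }
}
```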

         #region Min
         [Fact]
-        public static void Min_Tensor_ThrowsForEmpty()
+        public void Min_Tensor_ThrowsForEmpty()
         {
-            Assert.Throws<ArgumentException>(() => TensorPrimitives.Min(ReadOnlySpan<float>.Empty));
+            Assert.Throws<ArgumentException>(() => Min(ReadOnlySpan<T>.Empty));
         }

-        [Theory]
-        [MemberData(nameof(TensorLengths))]
-        public static void Min_Tensor(int tensorLength)
+        [Fact]
+        public void Min_Tensor()
         {
-            using BoundedMemory<float> x = CreateAndFillTensor(tensorLength);
+            Assert.All(Helpers.TensorLengths, tensorLength =>
+            {
+                using BoundedMemory<T> x = CreateAndFillTensor(tensorLength);

-            Assert.Equal(Enumerable.Min(MemoryMarshal.ToEnumerable(x.Memory)), TensorPrimitives.Min(x));
+                Assert.Equal(Enumerable.Min(MemoryMarshal.ToEnumerable(x.Memory)), Min(x));

-            float min = float.PositiveInfinity;
-            foreach (float f in x.Span)
-            {
-                min = Math.Min(min, f);
-            }
+                T min = ConvertFromSingle(float.PositiveInfinity);
+                foreach (T f in x.Span)
+                {
+                    min = Min(min, f);
+                }
+
+                Assert.Equal(min, Min(x));

-            Assert.Equal(min, TensorPrimitives.Min(x));
-            Assert.Equal(SingleToUInt32(x[TensorPrimitives.IndexOfMin(x)]), SingleToUInt32(TensorPrimitives.Min(x)));
+                // TODO: Put a variant of this back once we have IndexOf routines
+                // Assert.Equal(SingleToUInt32(x[IndexOfMin(x)]), SingleToUInt32(Min(x)));
+            });
         }

-        [Theory]
-        [MemberData(nameof(TensorLengths))]
-        public static void Min_Tensor_SpecialValues(int tensorLength)
+        [Fact]
+        public void Min_Tensor_SpecialValues()
         {
-            using BoundedMemory<float> x = CreateAndFillTensor(tensorLength);
-
-            RunForEachSpecialValue(() =>
+            Assert.All(Helpers.TensorLengths, tensorLength =>
             {
-                float min = float.PositiveInfinity;
-                foreach (float f in x.Span)
+                using BoundedMemory<T> x = CreateAndFillTensor(tensorLength);
+
+                RunForEachSpecialValue(() =>
                 {
-                    min = Math.Min(min, f);
-                }
+                    T min = ConvertFromSingle(float.PositiveInfinity);
+                    foreach (T f in x.Span)
+                    {
+                        min = Min(min, f);
+                    }
+
+                    Assert.Equal(min, Min(x));

-                Assert.Equal(min, TensorPrimitives.Min(x));
-                Assert.Equal(SingleToUInt32(x[TensorPrimitives.IndexOfMin(x)]), SingleToUInt32(TensorPrimitives.Min(x)));
-            }, x);
+                    // TODO: Put a variant of this back once we have IndexOf routines
+                    // Assert.Equal(SingleToUInt32(x[IndexOfMin(x)]), SingleToUInt32(Min(x)));
+                }, x);
+            });
         }

-        [Theory]
-        [MemberData(nameof(TensorLengths))]
-        public static void Min_Tensor_NanReturned(int tensorLength)
+        [Fact]
+        public void Min_Tensor_NanReturned()
         {
-            using BoundedMemory<float> x = CreateTensor(tensorLength);
-            foreach (int expected in new[] { 0, tensorLength / 2, tensorLength - 1 })
+            if (!IsFloatingPoint) return;
+
+            Assert.All(Helpers.TensorLengths, tensorLength =>
             {
-                FillTensor(x);
-                x[expected] = float.NaN;
-                Assert.Equal(float.NaN, TensorPrimitives.Min(x));
-            }
+                using BoundedMemory<T> x = CreateTensor(tensorLength);
+                foreach (int expected in new[] { 0, tensorLength / 2, tensorLength - 1 })
+                {
+                    FillTensor(x);
+                    x[expected] = ConvertFromSingle(float.NaN);
+                    Assert.Equal(ConvertFromSingle(float.NaN), Min(x));
+                }
+            });
         }

         [Fact]
-        public static void Min_Tensor_Negative0LesserThanPositive0()
+        public void Min_Tensor_Negative0LesserThanPositive0()
         {
-            Assert.Equal(-0f, TensorPrimitives.Min([-0f, +0f]));
-            Assert.Equal(-0f, TensorPrimitives.Min([+0f, -0f]));
-            Assert.Equal(-1, TensorPrimitives.Min([-1, -0f]));
-            Assert.Equal(-1, TensorPrimitives.Min([-1, -0f, 1]));
+            Assert.Equal(ConvertFromSingle(-0f), Min([ConvertFromSingle(-0f), ConvertFromSingle(+0f)]));
+            Assert.Equal(ConvertFromSingle(-0f), Min([ConvertFromSingle(+0f), ConvertFromSingle(-0f)]));
+            Assert.Equal(ConvertFromSingle(-1), Min([ConvertFromSingle(-1), ConvertFromSingle(-0f)]));
+            Assert.Equal(ConvertFromSingle(-1), Min([ConvertFromSingle(-1), ConvertFromSingle(-0f), ConvertFromSingle(1)]));
         }

-        [Theory]
-        [MemberData(nameof(TensorLengthsIncluding0))]
-        public static void Min_TwoTensors(int tensorLength)
+        [Fact]
+        public void Min_TwoTensors()
         {
-            using BoundedMemory<float> x = CreateAndFillTensor(tensorLength);
-            using BoundedMemory<float> y = CreateAndFillTensor(tensorLength);
-            using BoundedMemory<float> destination = CreateTensor(tensorLength);
+            Assert.All(Helpers.TensorLengthsIncluding0, tensorLength =>
+            {
+                using BoundedMemory<T> x = CreateAndFillTensor(tensorLength);
+                using BoundedMemory<T> y = CreateAndFillTensor(tensorLength);
+                using BoundedMemory<T> destination = CreateTensor(tensorLength);

-            TensorPrimitives.Min(x, y, destination);
+                Min(x, y, destination);

-            for (int i = 0; i < tensorLength; i++)
-            {
-                AssertEqualTolerance(MathF.Min(x[i], y[i]), destination[i]);
-            }
+                for (int i = 0; i < tensorLength; i++)
+                {
+                    AssertEqualTolerance(Min(x[i], y[i]), destination[i]);
+                }
+            });
         }

-        [Theory]
-        [MemberData(nameof(TensorLengthsIncluding0))]
-        public static void Min_TwoTensors_InPlace(int tensorLength)
+        [Fact]
+        public void Min_TwoTensors_InPlace()
         {
-            using BoundedMemory<float> x = CreateAndFillTensor(tensorLength);
-            using BoundedMemory<float> y = CreateAndFillTensor(tensorLength);
-            float[] xOrig = x.Span.ToArray(), yOrig = y.Span.ToArray();
+            Assert.All(Helpers.TensorLengthsIncluding0, tensorLength =>
+            {
+                using BoundedMemory<T> x = CreateAndFillTensor(tensorLength);
+                using BoundedMemory<T> y = CreateAndFillTensor(tensorLength);
+                T[] xOrig = x.Span.ToArray(), yOrig = y.Span.ToArray();

-            TensorPrimitives.Min(x, y, x);
+                Min(x, y, x);

-            for (int i = 0; i < tensorLength; i++)
-            {
-                AssertEqualTolerance(MathF.Min(xOrig[i], y[i]), x[i]);
-            }
+                for (int i = 0; i < tensorLength; i++)
+                {
+                    AssertEqualTolerance(Min(xOrig[i], y[i]), x[i]);
+                }

-            xOrig.AsSpan().CopyTo(x.Span);
-            yOrig.AsSpan().CopyTo(y.Span);
+                xOrig.AsSpan().CopyTo(x.Span);
+                yOrig.AsSpan().CopyTo(y.Span);

-            TensorPrimitives.Min(x, y, y);
+                Min(x, y, y);

-            for (int i = 0; i < tensorLength; i++)
-            {
-                AssertEqualTolerance(MathF.Min(x[i], yOrig[i]), y[i]);
-            }
+                for (int i = 0; i < tensorLength; i++)
+                {
+                    AssertEqualTolerance(Min(x[i], yOrig[i]), y[i]);
+                }
+            });
         }

-        [Theory]
-        [MemberData(nameof(TensorLengths))]
-        public static void Min_TwoTensors_SpecialValues(int tensorLength)
+        [Fact]
+        public void Min_TwoTensors_SpecialValues()
         {
-            using BoundedMemory<float> x = CreateAndFillTensor(tensorLength);
-            using BoundedMemory<float> y = CreateAndFillTensor(tensorLength);
-            using BoundedMemory<float> destination = CreateTensor(tensorLength);
+            Assert.All(Helpers.TensorLengths, tensorLength =>
+            {
+                using BoundedMemory<T> x = CreateAndFillTensor(tensorLength);
+                using BoundedMemory<T> y = CreateAndFillTensor(tensorLength);
+                using BoundedMemory<T> destination = CreateTensor(tensorLength);

-            SetSpecialValues(x, y);
+                SetSpecialValues(x, y);

-            TensorPrimitives.Min(x, y, destination);
-            for (int i = 0; i < tensorLength; i++)
-            {
-                AssertEqualTolerance(MathF.Min(x[i], y[i]), destination[i]);
-            }
+                Min(x, y, destination);
+                for (int i = 0; i < tensorLength; i++)
+                {
+                    AssertEqualTolerance(Min(x[i], y[i]), destination[i]);
+                }

-            TensorPrimitives.Min(y, x, destination);
-            for (int i = 0; i < tensorLength; i++)
-            {
-                AssertEqualTolerance(MathF.Min(y[i], x[i]), destination[i]);
-            }
+                Min(y, x, destination);
+                for (int i = 0; i < tensorLength; i++)
+                {
+                    AssertEqualTolerance(Min(y[i], x[i]), destination[i]);
+                }
+            });
         }

-        [Theory]
-        [MemberData(nameof(TensorLengths))]
-        public static void Min_TwoTensors_ThrowsForMismatchedLengths(int tensorLength)
+        [Fact]
+        public void Min_TwoTensors_ThrowsForMismatchedLengths()
         {
-            using BoundedMemory<float> x = CreateAndFillTensor(tensorLength);
-            using BoundedMemory<float> y = CreateAndFillTensor(tensorLength - 1);
-            using BoundedMemory<float> destination = CreateTensor(tensorLength);
+            Assert.All(Helpers.TensorLengths, tensorLength =>
+            {
+                using BoundedMemory<T> x = CreateAndFillTensor(tensorLength);
+                using BoundedMemory<T> y = CreateAndFillTensor(tensorLength - 1);
+                using BoundedMemory<T> destination = CreateTensor(tensorLength);

-            Assert.Throws<ArgumentException>(() => TensorPrimitives.Min(x, y, destination));
-            Assert.Throws<ArgumentException>(() => TensorPrimitives.Min(y, x, destination));
+                Assert.Throws<ArgumentException>(() => Min(x, y, destination));
+                Assert.Throws<ArgumentException>(() => Min(y, x, destination));
+            });
         }

-        [Theory]
-        [MemberData(nameof(TensorLengths))]
-        public static void Min_TwoTensors_ThrowsForTooShortDestination(int tensorLength)
+        [Fact]
+        public void Min_TwoTensors_ThrowsForTooShortDestination()
         {
-            using BoundedMemory<float> x = CreateAndFillTensor(tensorLength);
-            using BoundedMemory<float> y = CreateAndFillTensor(tensorLength);
-            using BoundedMemory<float> destination = CreateTensor(tensorLength - 1);
+            Assert.All(Helpers.TensorLengths, tensorLength =>
+            {
+                using BoundedMemory<T> x = CreateAndFillTensor(tensorLength);
+                using BoundedMemory<T> y = CreateAndFillTensor(tensorLength);
+                using BoundedMemory<T> destination = CreateTensor(tensorLength - 1);

-            AssertExtensions.Throws<ArgumentException>("destination", () => TensorPrimitives.Min(x, y, destination));
+                AssertExtensions.Throws<ArgumentException>("destination", () => Min(x, y, destination));
+            });
         }

         [Fact]
-        public static void Min_TwoTensors_ThrowsForOverlapppingInputsWithOutputs()
+        public void Min_TwoTensors_ThrowsForOverlapppingInputsWithOutputs()
         {
-            float[] array = new float[10];
-            AssertExtensions.Throws<ArgumentException>("destination", () => TensorPrimitives.Min(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(0, 2)));
-            AssertExtensions.Throws<ArgumentException>("destination", () => TensorPrimitives.Min(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(2, 2)));
-            AssertExtensions.Throws<ArgumentException>("destination", () => TensorPrimitives.Min(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(3, 2)));
-            AssertExtensions.Throws<ArgumentException>("destination", () => TensorPrimitives.Min(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(5, 2)));
+            T[] array = new T[10];
+            AssertExtensions.Throws<ArgumentException>("destination", () => Min(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(0, 2)));
+            AssertExtensions.Throws<ArgumentException>("destination", () => Min(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(2, 2)));
+            AssertExtensions.Throws<ArgumentException>("destination", () => Min(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(3, 2)));
+            AssertExtensions.Throws<ArgumentException>("destination", () => Min(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(5, 2)));
         }
         #endregion

         #region MinMagnitude
         [Fact]
-        public static void MinMagnitude_Tensor_ThrowsForEmpty()
+        public void MinMagnitude_Tensor_ThrowsForEmpty()
         {
-            Assert.Throws<ArgumentException>(() => TensorPrimitives.MinMagnitude(ReadOnlySpan<float>.Empty));
+            Assert.Throws<ArgumentException>(() => MinMagnitude(ReadOnlySpan<T>.Empty));
         }

-        [Theory]
-        [MemberData(nameof(TensorLengths))]
-        public static void MinMagnitude_Tensor(int tensorLength)
+        [ActiveIssue("https://github.com/dotnet/runtime/issues/96443", TestRuntimes.Mono)]
+        [Fact]
+        public void MinMagnitude_Tensor()
         {
-            using BoundedMemory<float> x = CreateAndFillTensor(tensorLength);
-
-            float minMagnitude = x[0];
-            foreach (float f in x.Span)
+            Assert.All(Helpers.TensorLengths, tensorLength =>
             {
-                minMagnitude = MathFMinMagnitude(minMagnitude, f);
-            }
+                using BoundedMemory<T> x = CreateAndFillTensor(tensorLength);
+
+                T minMagnitude = x[0];
+                foreach (T f in x.Span)
+                {
+                    minMagnitude = MinMagnitude(minMagnitude, f);
+                }
+
+                Assert.Equal(minMagnitude, MinMagnitude(x));

-            Assert.Equal(minMagnitude, TensorPrimitives.MinMagnitude(x));
-            Assert.Equal(SingleToUInt32(x[TensorPrimitives.IndexOfMinMagnitude(x)]), SingleToUInt32(TensorPrimitives.MinMagnitude(x)));
+                // TODO: Put a variant of this back once we have IndexOf routines
+                // Assert.Equal(SingleToUInt32(x[IndexOfMinMagnitude(x)]), SingleToUInt32(MinMagnitude(x)));
+            });
         }

-        [Theory]
-        [MemberData(nameof(TensorLengths))]
-        public static void MinMagnitude_Tensor_SpecialValues(int tensorLength)
+        [Fact]
+        public void MinMagnitude_Tensor_SpecialValues()
         {
-            using BoundedMemory<float> x = CreateAndFillTensor(tensorLength);
-
-            RunForEachSpecialValue(() =>
+            Assert.All(Helpers.TensorLengths, tensorLength =>
             {
-                float minMagnitude = x[0];
-                foreach (float f in x.Span)
+                using BoundedMemory<T> x = CreateAndFillTensor(tensorLength);
+
+                RunForEachSpecialValue(() =>
                 {
-                    minMagnitude = MathFMinMagnitude(minMagnitude, f);
-                }
+                    T minMagnitude = x[0];
+                    foreach (T f in x.Span)
+                    {
+                        minMagnitude = MinMagnitude(minMagnitude, f);
+                    }
+
+                    Assert.Equal(minMagnitude, MinMagnitude(x));

-                Assert.Equal(minMagnitude, TensorPrimitives.MinMagnitude(x));
-                Assert.Equal(SingleToUInt32(x[TensorPrimitives.IndexOfMinMagnitude(x)]), SingleToUInt32(TensorPrimitives.MinMagnitude(x)));
-            }, x);
+                    // TODO: Put a variant of this back once we have IndexOf routines
+                    // Assert.Equal(SingleToUInt32(x[IndexOfMinMagnitude(x)]), SingleToUInt32(MinMagnitude(x)));
+                }, x);
+            });
         }

-        [Theory]
-        [MemberData(nameof(TensorLengths))]
-        public static void MinMagnitude_Tensor_NanReturned(int tensorLength)
+        [Fact]
+        public void MinMagnitude_Tensor_NanReturned()
         {
-            using BoundedMemory<float> x = CreateTensor(tensorLength);
-            foreach (int expected in new[] { 0, tensorLength / 2, tensorLength - 1 })
+            if (!IsFloatingPoint) return;
+
+            Assert.All(Helpers.TensorLengths, tensorLength =>
             {
-                FillTensor(x);
-                x[expected] = float.NaN;
-                Assert.Equal(float.NaN, TensorPrimitives.MinMagnitude(x));
-            }
+                using BoundedMemory<T> x = CreateTensor(tensorLength);
+                foreach (int expected in new[] { 0, tensorLength / 2, tensorLength - 1 })
+                {
+                    FillTensor(x);
+                    x[expected] = ConvertFromSingle(float.NaN);
+                    Assert.Equal(ConvertFromSingle(float.NaN), MinMagnitude(x));
+                }
+            });
         }

         [Fact]
-        public
static void MinMagnitude_Tensor_Negative0LesserThanPositive0() + public void MinMagnitude_Tensor_Negative0LesserThanPositive0() { - Assert.Equal(0, TensorPrimitives.MinMagnitude([-0f, +0f])); - Assert.Equal(0, TensorPrimitives.MinMagnitude([+0f, -0f])); - Assert.Equal(0, TensorPrimitives.MinMagnitude([-1, -0f])); - Assert.Equal(0, TensorPrimitives.MinMagnitude([-1, -0f, 1])); + Assert.Equal(ConvertFromSingle(0), MinMagnitude([ConvertFromSingle(-0f), ConvertFromSingle(+0f)])); + Assert.Equal(ConvertFromSingle(0), MinMagnitude([ConvertFromSingle(+0f), ConvertFromSingle(-0f)])); + Assert.Equal(ConvertFromSingle(0), MinMagnitude([ConvertFromSingle(-1), ConvertFromSingle(-0f)])); + Assert.Equal(ConvertFromSingle(0), MinMagnitude([ConvertFromSingle(-1), ConvertFromSingle(-0f), ConvertFromSingle(1)])); } - [Theory] - [MemberData(nameof(TensorLengthsIncluding0))] - public static void MinMagnitude_TwoTensors(int tensorLength) + [ActiveIssue("https://github.com/dotnet/runtime/issues/96443", TestRuntimes.Mono)] + [Fact] + public void MinMagnitude_TwoTensors() { - using BoundedMemory x = CreateAndFillTensor(tensorLength); - using BoundedMemory y = CreateAndFillTensor(tensorLength); - using BoundedMemory destination = CreateTensor(tensorLength); + Assert.All(Helpers.TensorLengthsIncluding0, tensorLength => + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory y = CreateAndFillTensor(tensorLength); + using BoundedMemory destination = CreateTensor(tensorLength); - TensorPrimitives.MinMagnitude(x, y, destination); + MinMagnitude(x, y, destination); - for (int i = 0; i < tensorLength; i++) - { - AssertEqualTolerance(MathFMinMagnitude(x[i], y[i]), destination[i]); - } + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(MinMagnitude(x[i], y[i]), destination[i]); + } + }); } - [Theory] - [MemberData(nameof(TensorLengthsIncluding0))] - public static void MinMagnitude_TwoTensors_InPlace(int tensorLength) + [ActiveIssue("https://github.com/dotnet/runtime/issues/96443", TestRuntimes.Mono)] + [Fact] + public void MinMagnitude_TwoTensors_InPlace() { - using BoundedMemory x = CreateAndFillTensor(tensorLength); - using BoundedMemory y = CreateAndFillTensor(tensorLength); - float[] xOrig = x.Span.ToArray(), yOrig = y.Span.ToArray(); + Assert.All(Helpers.TensorLengthsIncluding0, tensorLength => + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory y = CreateAndFillTensor(tensorLength); + T[] xOrig = x.Span.ToArray(), yOrig = y.Span.ToArray(); - TensorPrimitives.MinMagnitude(x, y, x); + MinMagnitude(x, y, x); - for (int i = 0; i < tensorLength; i++) - { - AssertEqualTolerance(MathFMinMagnitude(xOrig[i], y[i]), x[i]); - } + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(MinMagnitude(xOrig[i], y[i]), x[i]); + } - xOrig.AsSpan().CopyTo(x.Span); - yOrig.AsSpan().CopyTo(y.Span); + xOrig.AsSpan().CopyTo(x.Span); + yOrig.AsSpan().CopyTo(y.Span); - TensorPrimitives.MinMagnitude(x, y, y); + MinMagnitude(x, y, y); - for (int i = 0; i < tensorLength; i++) - { - AssertEqualTolerance(MathFMinMagnitude(x[i], yOrig[i]), y[i]); - } + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(MinMagnitude(x[i], yOrig[i]), y[i]); + } + }); } - [Theory] - [MemberData(nameof(TensorLengths))] - public static void MinMagnitude_TwoTensors_SpecialValues(int tensorLength) + [ActiveIssue("https://github.com/dotnet/runtime/issues/96443", TestRuntimes.Mono)] + [Fact] + public void MinMagnitude_TwoTensors_SpecialValues() { - using 
BoundedMemory x = CreateAndFillTensor(tensorLength); - using BoundedMemory y = CreateAndFillTensor(tensorLength); - using BoundedMemory destination = CreateTensor(tensorLength); + Assert.All(Helpers.TensorLengths, tensorLength => + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory y = CreateAndFillTensor(tensorLength); + using BoundedMemory destination = CreateTensor(tensorLength); - SetSpecialValues(x, y); + SetSpecialValues(x, y); - TensorPrimitives.MinMagnitude(x, y, destination); - for (int i = 0; i < tensorLength; i++) - { - AssertEqualTolerance(MathFMinMagnitude(x[i], y[i]), destination[i]); - } + MinMagnitude(x, y, destination); + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(MinMagnitude(x[i], y[i]), destination[i]); + } - TensorPrimitives.MinMagnitude(y, x, destination); - for (int i = 0; i < tensorLength; i++) - { - AssertEqualTolerance(MathFMinMagnitude(y[i], x[i]), destination[i]); - } + MinMagnitude(y, x, destination); + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(MinMagnitude(y[i], x[i]), destination[i]); + } + }); } - [Theory] - [MemberData(nameof(TensorLengths))] - public static void MinMagnitude_TwoTensors_ThrowsForMismatchedLengths(int tensorLength) + [Fact] + public void MinMagnitude_TwoTensors_ThrowsForMismatchedLengths() { - using BoundedMemory x = CreateAndFillTensor(tensorLength); - using BoundedMemory y = CreateAndFillTensor(tensorLength - 1); - using BoundedMemory destination = CreateTensor(tensorLength); + Assert.All(Helpers.TensorLengths, tensorLength => + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory y = CreateAndFillTensor(tensorLength - 1); + using BoundedMemory destination = CreateTensor(tensorLength); - Assert.Throws(() => TensorPrimitives.MinMagnitude(x, y, destination)); - Assert.Throws(() => TensorPrimitives.MinMagnitude(y, x, destination)); + Assert.Throws(() => MinMagnitude(x, y, destination)); + Assert.Throws(() => MinMagnitude(y, x, destination)); + }); } - [Theory] - [MemberData(nameof(TensorLengths))] - public static void MinMagnitude_TwoTensors_ThrowsForTooShortDestination(int tensorLength) + [Fact] + public void MinMagnitude_TwoTensors_ThrowsForTooShortDestination() { - using BoundedMemory x = CreateAndFillTensor(tensorLength); - using BoundedMemory y = CreateAndFillTensor(tensorLength); - using BoundedMemory destination = CreateTensor(tensorLength - 1); + Assert.All(Helpers.TensorLengths, tensorLength => + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory y = CreateAndFillTensor(tensorLength); + using BoundedMemory destination = CreateTensor(tensorLength - 1); - AssertExtensions.Throws("destination", () => TensorPrimitives.MinMagnitude(x, y, destination)); + AssertExtensions.Throws("destination", () => MinMagnitude(x, y, destination)); + }); } [Fact] - public static void MinMagnitude_TwoTensors_ThrowsForOverlapppingInputsWithOutputs() + public void MinMagnitude_TwoTensors_ThrowsForOverlapppingInputsWithOutputs() { - float[] array = new float[10]; - AssertExtensions.Throws("destination", () => TensorPrimitives.MinMagnitude(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(0, 2))); - AssertExtensions.Throws("destination", () => TensorPrimitives.MinMagnitude(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(2, 2))); - AssertExtensions.Throws("destination", () => TensorPrimitives.MinMagnitude(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(3, 2))); - AssertExtensions.Throws("destination", () 
=> TensorPrimitives.MinMagnitude(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(5, 2))); + T[] array = new T[10]; + AssertExtensions.Throws("destination", () => MinMagnitude(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(0, 2))); + AssertExtensions.Throws("destination", () => MinMagnitude(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(2, 2))); + AssertExtensions.Throws("destination", () => MinMagnitude(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(3, 2))); + AssertExtensions.Throws("destination", () => MinMagnitude(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(5, 2))); } #endregion #region Multiply - [Theory] - [MemberData(nameof(TensorLengthsIncluding0))] - public static void Multiply_TwoTensors(int tensorLength) + [Fact] + public void Multiply_TwoTensors() { - using BoundedMemory x = CreateAndFillTensor(tensorLength); - using BoundedMemory y = CreateAndFillTensor(tensorLength); - using BoundedMemory destination = CreateTensor(tensorLength); + Assert.All(Helpers.TensorLengthsIncluding0, tensorLength => + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory y = CreateAndFillTensor(tensorLength); + using BoundedMemory destination = CreateTensor(tensorLength); - TensorPrimitives.Multiply(x, y, destination); + Multiply(x, y, destination); - for (int i = 0; i < tensorLength; i++) - { - AssertEqualTolerance(x[i] * y[i], destination[i]); - } + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(Multiply(x[i], y[i]), destination[i]); + } + }); } - [Theory] - [MemberData(nameof(TensorLengthsIncluding0))] - public static void Multiply_TwoTensors_InPlace(int tensorLength) + [Fact] + public void Multiply_TwoTensors_InPlace() { - using BoundedMemory x = CreateAndFillTensor(tensorLength); - float[] xOrig = x.Span.ToArray(); + Assert.All(Helpers.TensorLengthsIncluding0, tensorLength => + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + T[] xOrig = x.Span.ToArray(); - TensorPrimitives.Multiply(x, x, x); + Multiply(x, x, x); - for (int i = 0; i < tensorLength; i++) - { - AssertEqualTolerance(xOrig[i] * xOrig[i], x[i]); - } + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(Multiply(xOrig[i], xOrig[i]), x[i]); + } + }); } - [Theory] - [MemberData(nameof(TensorLengths))] - public static void Multiply_TwoTensors_ThrowsForMismatchedLengths(int tensorLength) + [Fact] + public void Multiply_TwoTensors_ThrowsForMismatchedLengths() { - using BoundedMemory x = CreateAndFillTensor(tensorLength); - using BoundedMemory y = CreateAndFillTensor(tensorLength - 1); - using BoundedMemory destination = CreateTensor(tensorLength); + Assert.All(Helpers.TensorLengths, tensorLength => + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory y = CreateAndFillTensor(tensorLength - 1); + using BoundedMemory destination = CreateTensor(tensorLength); - Assert.Throws(() => TensorPrimitives.Multiply(x, y, destination)); - Assert.Throws(() => TensorPrimitives.Multiply(y, x, destination)); + Assert.Throws(() => Multiply(x, y, destination)); + Assert.Throws(() => Multiply(y, x, destination)); + }); } - [Theory] - [MemberData(nameof(TensorLengths))] - public static void Multiply_TwoTensors_ThrowsForTooShortDestination(int tensorLength) + [Fact] + public void Multiply_TwoTensors_ThrowsForTooShortDestination() { - using BoundedMemory x = CreateAndFillTensor(tensorLength); - using BoundedMemory y = CreateAndFillTensor(tensorLength); - using BoundedMemory destination = CreateTensor(tensorLength - 1); + 
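
A note on the recurring mechanical change throughout this diff: each per-length [Theory]/[MemberData(nameof(TensorLengths))] test becomes a single [Fact] that iterates the same lengths via xUnit's Assert.All, which runs the lambda for every length and aggregates all failures into one report instead of stopping at the first. A minimal sketch of the before/after shape follows; SomeOp and SomeOpScalar are hypothetical stand-ins, and Helpers.TensorLengths is assumed to enumerate the same lengths the old MemberData supplied:

    // Before: xUnit materialized one test case per tensor length.
    // [Theory]
    // [MemberData(nameof(TensorLengths))]
    // public static void SomeOp(int tensorLength) { ... }

    // After: one instance [Fact] covers every length for the current T.
    [Fact]
    public void SomeOp_AllLengths()
    {
        Assert.All(Helpers.TensorLengths, tensorLength =>
        {
            using BoundedMemory<T> x = CreateAndFillTensor(tensorLength);
            using BoundedMemory<T> destination = CreateTensor(tensorLength);

            SomeOp(x, destination); // dispatches to TensorPrimitives.SomeOp<T> in the derived class

            for (int i = 0; i < tensorLength; i++)
            {
                // Scalar oracle computed element by element; also hypothetical here.
                AssertEqualTolerance(SomeOpScalar(x[i]), destination[i]);
            }
        });
    }

The trade-off is coarser granularity in the test runner (one result per operation rather than one per length), in exchange for a single generic test body that can be instantiated once per element type.
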
Assert.All(Helpers.TensorLengths, tensorLength => + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory y = CreateAndFillTensor(tensorLength); + using BoundedMemory destination = CreateTensor(tensorLength - 1); - AssertExtensions.Throws("destination", () => TensorPrimitives.Multiply(x, y, destination)); + AssertExtensions.Throws("destination", () => Multiply(x, y, destination)); + }); } [Fact] - public static void Multiply_TwoTensors_ThrowsForOverlapppingInputsWithOutputs() + public void Multiply_TwoTensors_ThrowsForOverlapppingInputsWithOutputs() { - float[] array = new float[10]; - AssertExtensions.Throws("destination", () => TensorPrimitives.Multiply(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(0, 2))); - AssertExtensions.Throws("destination", () => TensorPrimitives.Multiply(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(2, 2))); - AssertExtensions.Throws("destination", () => TensorPrimitives.Multiply(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(3, 2))); - AssertExtensions.Throws("destination", () => TensorPrimitives.Multiply(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(5, 2))); + T[] array = new T[10]; + AssertExtensions.Throws("destination", () => Multiply(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(0, 2))); + AssertExtensions.Throws("destination", () => Multiply(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(2, 2))); + AssertExtensions.Throws("destination", () => Multiply(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(3, 2))); + AssertExtensions.Throws("destination", () => Multiply(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(5, 2))); } - [Theory] - [MemberData(nameof(TensorLengthsIncluding0))] - public static void Multiply_TensorScalar(int tensorLength) + [Fact] + public void Multiply_TensorScalar() { - using BoundedMemory x = CreateAndFillTensor(tensorLength); - float y = NextSingle(); - using BoundedMemory destination = CreateTensor(tensorLength); + Assert.All(Helpers.TensorLengthsIncluding0, tensorLength => + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + T y = NextRandom(); + using BoundedMemory destination = CreateTensor(tensorLength); - TensorPrimitives.Multiply(x, y, destination); + Multiply(x, y, destination); - for (int i = 0; i < tensorLength; i++) - { - AssertEqualTolerance(x[i] * y, destination[i]); - } + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(Multiply(x[i], y), destination[i]); + } + }); } - [Theory] - [MemberData(nameof(TensorLengthsIncluding0))] - public static void Multiply_TensorScalar_InPlace(int tensorLength) + [Fact] + public void Multiply_TensorScalar_InPlace() { - using BoundedMemory x = CreateAndFillTensor(tensorLength); - float[] xOrig = x.Span.ToArray(); - float y = NextSingle(); + Assert.All(Helpers.TensorLengthsIncluding0, tensorLength => + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + T[] xOrig = x.Span.ToArray(); + T y = NextRandom(); - TensorPrimitives.Multiply(x, y, x); + Multiply(x, y, x); - for (int i = 0; i < tensorLength; i++) - { - AssertEqualTolerance(xOrig[i] * y, x[i]); - } + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(Multiply(xOrig[i], y), x[i]); + } + }); } - [Theory] - [MemberData(nameof(TensorLengths))] - public static void Multiply_TensorScalar_ThrowsForTooShortDestination(int tensorLength) + [Fact] + public void Multiply_TensorScalar_ThrowsForTooShortDestination() { - using BoundedMemory x = CreateAndFillTensor(tensorLength); - float y = NextSingle(); - using 
BoundedMemory destination = CreateTensor(tensorLength - 1); + Assert.All(Helpers.TensorLengths, tensorLength => + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + T y = NextRandom(); + using BoundedMemory destination = CreateTensor(tensorLength - 1); - AssertExtensions.Throws("destination", () => TensorPrimitives.Multiply(x, y, destination)); + AssertExtensions.Throws("destination", () => Multiply(x, y, destination)); + }); } [Fact] - public static void Multiply_TensorScalar_ThrowsForOverlapppingInputsWithOutputs() + public void Multiply_TensorScalar_ThrowsForOverlapppingInputsWithOutputs() { - float[] array = new float[10]; - AssertExtensions.Throws("destination", () => TensorPrimitives.Multiply(array.AsSpan(1, 2), 42, array.AsSpan(0, 2))); - AssertExtensions.Throws("destination", () => TensorPrimitives.Multiply(array.AsSpan(1, 2), 42, array.AsSpan(2, 2))); + T[] array = new T[10]; + AssertExtensions.Throws("destination", () => Multiply(array.AsSpan(1, 2), default(T), array.AsSpan(0, 2))); + AssertExtensions.Throws("destination", () => Multiply(array.AsSpan(1, 2), default(T), array.AsSpan(2, 2))); } #endregion #region MultiplyAdd - [Theory] - [MemberData(nameof(TensorLengthsIncluding0))] - public static void MultiplyAdd_ThreeTensors(int tensorLength) + [Fact] + public void MultiplyAdd_ThreeTensors() { - using BoundedMemory x = CreateAndFillTensor(tensorLength); - using BoundedMemory y = CreateAndFillTensor(tensorLength); - using BoundedMemory addend = CreateAndFillTensor(tensorLength); - using BoundedMemory destination = CreateTensor(tensorLength); + Assert.All(Helpers.TensorLengthsIncluding0, tensorLength => + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory y = CreateAndFillTensor(tensorLength); + using BoundedMemory addend = CreateAndFillTensor(tensorLength); + using BoundedMemory destination = CreateTensor(tensorLength); - TensorPrimitives.MultiplyAdd(x, y, addend, destination); + MultiplyAdd(x, y, addend, destination); - for (int i = 0; i < tensorLength; i++) - { - AssertEqualTolerance((x[i] * y[i]) + addend[i], destination[i]); - } + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(Add(Multiply(x[i], y[i]), addend[i]), destination[i]); + } + }); } - [Theory] - [MemberData(nameof(TensorLengthsIncluding0))] - public static void MultiplyAdd_ThreeTensors_InPlace(int tensorLength) + [Fact] + public void MultiplyAdd_ThreeTensors_InPlace() { - using BoundedMemory x = CreateAndFillTensor(tensorLength); - float[] xOrig = x.Span.ToArray(); + Assert.All(Helpers.TensorLengthsIncluding0, tensorLength => + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + T[] xOrig = x.Span.ToArray(); - TensorPrimitives.MultiplyAdd(x, x, x, x); + MultiplyAdd(x, x, x, x); - for (int i = 0; i < tensorLength; i++) - { - AssertEqualTolerance((xOrig[i] * xOrig[i]) + xOrig[i], x[i]); - } + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(Add(Multiply(xOrig[i], xOrig[i]), xOrig[i]), x[i]); + } + }); } - [Theory] - [MemberData(nameof(TensorLengths))] - public static void MultiplyAdd_ThreeTensors_ThrowsForMismatchedLengths_x_y(int tensorLength) + [Fact] + public void MultiplyAdd_ThreeTensors_ThrowsForMismatchedLengths_x_y() { - using BoundedMemory x = CreateAndFillTensor(tensorLength); - using BoundedMemory y = CreateAndFillTensor(tensorLength); - using BoundedMemory z = CreateAndFillTensor(tensorLength - 1); - using BoundedMemory destination = CreateTensor(tensorLength); + Assert.All(Helpers.TensorLengths, tensorLength => 
+ { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory y = CreateAndFillTensor(tensorLength); + using BoundedMemory z = CreateAndFillTensor(tensorLength - 1); + using BoundedMemory destination = CreateTensor(tensorLength); - Assert.Throws(() => TensorPrimitives.MultiplyAdd(x, y, z, destination)); - Assert.Throws(() => TensorPrimitives.MultiplyAdd(x, z, y, destination)); - Assert.Throws(() => TensorPrimitives.MultiplyAdd(z, x, y, destination)); + Assert.Throws(() => MultiplyAdd(x, y, z, destination)); + Assert.Throws(() => MultiplyAdd(x, z, y, destination)); + Assert.Throws(() => MultiplyAdd(z, x, y, destination)); + }); } - [Theory] - [MemberData(nameof(TensorLengths))] - public static void MultiplyAdd_ThreeTensors_ThrowsForTooShortDestination(int tensorLength) + [Fact] + public void MultiplyAdd_ThreeTensors_ThrowsForTooShortDestination() { - using BoundedMemory x = CreateAndFillTensor(tensorLength); - using BoundedMemory y = CreateAndFillTensor(tensorLength); - using BoundedMemory addend = CreateAndFillTensor(tensorLength); - using BoundedMemory destination = CreateTensor(tensorLength - 1); + Assert.All(Helpers.TensorLengths, tensorLength => + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory y = CreateAndFillTensor(tensorLength); + using BoundedMemory addend = CreateAndFillTensor(tensorLength); + using BoundedMemory destination = CreateTensor(tensorLength - 1); - AssertExtensions.Throws("destination", () => TensorPrimitives.MultiplyAdd(x, y, addend, destination)); + AssertExtensions.Throws("destination", () => MultiplyAdd(x, y, addend, destination)); + }); } [Fact] - public static void MultiplyAdd_ThreeTensors_ThrowsForOverlapppingInputsWithOutputs() + public void MultiplyAdd_ThreeTensors_ThrowsForOverlapppingInputsWithOutputs() { - float[] array = new float[10]; - AssertExtensions.Throws("destination", () => TensorPrimitives.MultiplyAdd(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(7, 2), array.AsSpan(0, 2))); - AssertExtensions.Throws("destination", () => TensorPrimitives.MultiplyAdd(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(7, 2), array.AsSpan(2, 2))); - AssertExtensions.Throws("destination", () => TensorPrimitives.MultiplyAdd(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(7, 2), array.AsSpan(3, 2))); - AssertExtensions.Throws("destination", () => TensorPrimitives.MultiplyAdd(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(7, 2), array.AsSpan(5, 2))); - AssertExtensions.Throws("destination", () => TensorPrimitives.MultiplyAdd(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(7, 2), array.AsSpan(6, 2))); - AssertExtensions.Throws("destination", () => TensorPrimitives.MultiplyAdd(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(7, 2), array.AsSpan(8, 2))); + T[] array = new T[10]; + AssertExtensions.Throws("destination", () => MultiplyAdd(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(7, 2), array.AsSpan(0, 2))); + AssertExtensions.Throws("destination", () => MultiplyAdd(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(7, 2), array.AsSpan(2, 2))); + AssertExtensions.Throws("destination", () => MultiplyAdd(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(7, 2), array.AsSpan(3, 2))); + AssertExtensions.Throws("destination", () => MultiplyAdd(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(7, 2), array.AsSpan(5, 2))); + AssertExtensions.Throws("destination", () => MultiplyAdd(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(7, 2), array.AsSpan(6, 2))); + 
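
Before the remaining overlap cases, a word on what these ThrowsForOverlapppingInputsWithOutputs tests pin down: a destination span may alias an input only when it lines up with it exactly, which is the supported in-place form exercised by the _InPlace tests; any partial overlap must throw, since vectorized writes to the destination would clobber input elements that have not yet been read. A minimal sketch of such a guard, using the real MemoryExtensions.Overlaps overload that reports the element offset (the library's actual validation helper may be shaped differently):

    // Rejects partial aliasing between an input and the destination,
    // but permits exact aliasing (offset 0), i.e. the in-place form.
    static void ValidateInputOutputNonOverlapping<T>(ReadOnlySpan<T> input, Span<T> destination)
    {
        if (input.Overlaps(destination, out int elementOffset) && elementOffset != 0)
        {
            throw new ArgumentException("Destination partially overlaps a source argument.", nameof(destination));
        }
    }

AssertExtensions.Throws("destination", ...) in these tests then, in effect, checks for an ArgumentException carrying that parameter name.
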
AssertExtensions.Throws("destination", () => MultiplyAdd(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(7, 2), array.AsSpan(8, 2))); } - [Theory] - [MemberData(nameof(TensorLengthsIncluding0))] - public static void MultiplyAdd_TensorTensorScalar(int tensorLength) + [Fact] + public void MultiplyAdd_TensorTensorScalar() { - using BoundedMemory x = CreateAndFillTensor(tensorLength); - using BoundedMemory y = CreateAndFillTensor(tensorLength); - float addend = NextSingle(); - using BoundedMemory destination = CreateTensor(tensorLength); + Assert.All(Helpers.TensorLengthsIncluding0, tensorLength => + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory y = CreateAndFillTensor(tensorLength); + T addend = NextRandom(); + using BoundedMemory destination = CreateTensor(tensorLength); - TensorPrimitives.MultiplyAdd(x, y, addend, destination); + MultiplyAdd(x, y, addend, destination); - for (int i = 0; i < tensorLength; i++) - { - AssertEqualTolerance((x[i] * y[i]) + addend, destination[i]); - } + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(Add(Multiply(x[i], y[i]), addend), destination[i]); + } + }); } - [Theory] - [MemberData(nameof(TensorLengthsIncluding0))] - public static void MultiplyAdd_TensorTensorScalar_InPlace(int tensorLength) + [Fact] + public void MultiplyAdd_TensorTensorScalar_InPlace() { - using BoundedMemory x = CreateAndFillTensor(tensorLength); - float[] xOrig = x.Span.ToArray(); - float addend = NextSingle(); + Assert.All(Helpers.TensorLengthsIncluding0, tensorLength => + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + T[] xOrig = x.Span.ToArray(); + T addend = NextRandom(); - TensorPrimitives.MultiplyAdd(x, x, addend, x); + MultiplyAdd(x, x, addend, x); - for (int i = 0; i < tensorLength; i++) - { - AssertEqualTolerance((xOrig[i] * xOrig[i]) + addend, x[i]); - } + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(Add(Multiply(xOrig[i], xOrig[i]), addend), x[i]); + } + }); } - [Theory] - [MemberData(nameof(TensorLengths))] - public static void MultiplyAdd_TensorTensorScalar_ThrowsForTooShortDestination(int tensorLength) + [Fact] + public void MultiplyAdd_TensorTensorScalar_ThrowsForTooShortDestination() { - using BoundedMemory x = CreateAndFillTensor(tensorLength); - using BoundedMemory y = CreateAndFillTensor(tensorLength); - float addend = NextSingle(); - using BoundedMemory destination = CreateTensor(tensorLength - 1); + Assert.All(Helpers.TensorLengths, tensorLength => + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory y = CreateAndFillTensor(tensorLength); + T addend = NextRandom(); + using BoundedMemory destination = CreateTensor(tensorLength - 1); - AssertExtensions.Throws("destination", () => TensorPrimitives.MultiplyAdd(x, y, addend, destination)); + AssertExtensions.Throws("destination", () => MultiplyAdd(x, y, addend, destination)); + }); } [Fact] - public static void MultiplyAdd_TensorTensorScalar_ThrowsForOverlapppingInputsWithOutputs() + public void MultiplyAdd_TensorTensorScalar_ThrowsForOverlapppingInputsWithOutputs() { - float[] array = new float[10]; - AssertExtensions.Throws("destination", () => TensorPrimitives.MultiplyAdd(array.AsSpan(1, 2), array.AsSpan(4, 2), 42, array.AsSpan(0, 2))); - AssertExtensions.Throws("destination", () => TensorPrimitives.MultiplyAdd(array.AsSpan(1, 2), array.AsSpan(4, 2), 42, array.AsSpan(2, 2))); - AssertExtensions.Throws("destination", () => TensorPrimitives.MultiplyAdd(array.AsSpan(1, 2), 
array.AsSpan(4, 2), 42, array.AsSpan(3, 2))); - AssertExtensions.Throws("destination", () => TensorPrimitives.MultiplyAdd(array.AsSpan(1, 2), array.AsSpan(4, 2), 42, array.AsSpan(5, 2))); + T[] array = new T[10]; + AssertExtensions.Throws("destination", () => MultiplyAdd(array.AsSpan(1, 2), array.AsSpan(4, 2), default(T), array.AsSpan(0, 2))); + AssertExtensions.Throws("destination", () => MultiplyAdd(array.AsSpan(1, 2), array.AsSpan(4, 2), default(T), array.AsSpan(2, 2))); + AssertExtensions.Throws("destination", () => MultiplyAdd(array.AsSpan(1, 2), array.AsSpan(4, 2), default(T), array.AsSpan(3, 2))); + AssertExtensions.Throws("destination", () => MultiplyAdd(array.AsSpan(1, 2), array.AsSpan(4, 2), default(T), array.AsSpan(5, 2))); } - [Theory] - [MemberData(nameof(TensorLengthsIncluding0))] - public static void MultiplyAdd_TensorScalarTensor(int tensorLength) + [Fact] + public void MultiplyAdd_TensorScalarTensor() { - using BoundedMemory x = CreateAndFillTensor(tensorLength); - float y = NextSingle(); - using BoundedMemory addend = CreateAndFillTensor(tensorLength); - using BoundedMemory destination = CreateTensor(tensorLength); + Assert.All(Helpers.TensorLengthsIncluding0, tensorLength => + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + T y = NextRandom(); + using BoundedMemory addend = CreateAndFillTensor(tensorLength); + using BoundedMemory destination = CreateTensor(tensorLength); - TensorPrimitives.MultiplyAdd(x, y, addend, destination); + MultiplyAdd(x, y, addend, destination); - for (int i = 0; i < tensorLength; i++) - { - AssertEqualTolerance((x[i] * y) + addend[i], destination[i]); - } + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(Add(Multiply(x[i], y), addend[i]), destination[i]); + } + }); } - [Theory] - [MemberData(nameof(TensorLengthsIncluding0))] - public static void MultiplyAdd_TensorScalarTensor_InPlace(int tensorLength) + [Fact] + public void MultiplyAdd_TensorScalarTensor_InPlace() { - using BoundedMemory x = CreateAndFillTensor(tensorLength); - float[] xOrig = x.Span.ToArray(); - float y = NextSingle(); + Assert.All(Helpers.TensorLengthsIncluding0, tensorLength => + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + T[] xOrig = x.Span.ToArray(); + T y = NextRandom(); - TensorPrimitives.MultiplyAdd(x, y, x, x); + MultiplyAdd(x, y, x, x); - for (int i = 0; i < tensorLength; i++) - { - AssertEqualTolerance((xOrig[i] * y) + xOrig[i], x[i]); - } + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(Add(Multiply(xOrig[i], y), xOrig[i]), x[i]); + } + }); } - [Theory] - [MemberData(nameof(TensorLengths))] - public static void MultiplyAdd_TensorScalarTensor_ThrowsForTooShortDestination(int tensorLength) + [Fact] + public void MultiplyAdd_TensorScalarTensor_ThrowsForTooShortDestination() { - using BoundedMemory x = CreateAndFillTensor(tensorLength); - float y = NextSingle(); - using BoundedMemory addend = CreateAndFillTensor(tensorLength); - using BoundedMemory destination = CreateTensor(tensorLength - 1); + Assert.All(Helpers.TensorLengths, tensorLength => + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + T y = NextRandom(); + using BoundedMemory addend = CreateAndFillTensor(tensorLength); + using BoundedMemory destination = CreateTensor(tensorLength - 1); - AssertExtensions.Throws("destination", () => TensorPrimitives.MultiplyAdd(x, y, addend, destination)); + AssertExtensions.Throws("destination", () => MultiplyAdd(x, y, addend, destination)); + }); } [Fact] - public static void 
MultiplyAdd_TensorScalarTensor_ThrowsForOverlapppingInputsWithOutputs() + public void MultiplyAdd_TensorScalarTensor_ThrowsForOverlapppingInputsWithOutputs() { - float[] array = new float[10]; - AssertExtensions.Throws("destination", () => TensorPrimitives.MultiplyAdd(array.AsSpan(1, 2), 42, array.AsSpan(4, 2), array.AsSpan(0, 2))); - AssertExtensions.Throws("destination", () => TensorPrimitives.MultiplyAdd(array.AsSpan(1, 2), 42, array.AsSpan(4, 2), array.AsSpan(2, 2))); - AssertExtensions.Throws("destination", () => TensorPrimitives.MultiplyAdd(array.AsSpan(1, 2), 42, array.AsSpan(4, 2), array.AsSpan(3, 2))); - AssertExtensions.Throws("destination", () => TensorPrimitives.MultiplyAdd(array.AsSpan(1, 2), 42, array.AsSpan(4, 2), array.AsSpan(5, 2))); + T[] array = new T[10]; + AssertExtensions.Throws("destination", () => MultiplyAdd(array.AsSpan(1, 2), default(T), array.AsSpan(4, 2), array.AsSpan(0, 2))); + AssertExtensions.Throws("destination", () => MultiplyAdd(array.AsSpan(1, 2), default(T), array.AsSpan(4, 2), array.AsSpan(2, 2))); + AssertExtensions.Throws("destination", () => MultiplyAdd(array.AsSpan(1, 2), default(T), array.AsSpan(4, 2), array.AsSpan(3, 2))); + AssertExtensions.Throws("destination", () => MultiplyAdd(array.AsSpan(1, 2), default(T), array.AsSpan(4, 2), array.AsSpan(5, 2))); } #endregion #region Negate - [Theory] - [MemberData(nameof(TensorLengthsIncluding0))] - public static void Negate(int tensorLength) + [Fact] + public void Negate_AllLengths() { - using BoundedMemory x = CreateAndFillTensor(tensorLength); - using BoundedMemory destination = CreateTensor(tensorLength); + Assert.All(Helpers.TensorLengthsIncluding0, tensorLength => + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory destination = CreateTensor(tensorLength); - TensorPrimitives.Negate(x, destination); + Negate(x, destination); - for (int i = 0; i < tensorLength; i++) - { - AssertEqualTolerance(-x[i], destination[i]); - } + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(Multiply(x[i], NegativeOne), destination[i]); + } + }); } - [Theory] - [MemberData(nameof(TensorLengthsIncluding0))] - public static void Negate_InPlace(int tensorLength) + [Fact] + public void Negate_InPlace() { - using BoundedMemory x = CreateAndFillTensor(tensorLength); - float[] xOrig = x.Span.ToArray(); + Assert.All(Helpers.TensorLengthsIncluding0, tensorLength => + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + T[] xOrig = x.Span.ToArray(); - TensorPrimitives.Negate(x, x); + Negate(x, x); - for (int i = 0; i < tensorLength; i++) - { - AssertEqualTolerance(-xOrig[i], x[i]); - } + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(Multiply(xOrig[i], NegativeOne), x[i]); + } + }); } - [Theory] - [MemberData(nameof(TensorLengths))] - public static void Negate_ThrowsForTooShortDestination(int tensorLength) + [Fact] + public void Negate_ThrowsForTooShortDestination() { - using BoundedMemory x = CreateAndFillTensor(tensorLength); - using BoundedMemory destination = CreateTensor(tensorLength - 1); + Assert.All(Helpers.TensorLengths, tensorLength => + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory destination = CreateTensor(tensorLength - 1); - AssertExtensions.Throws("destination", () => TensorPrimitives.Negate(x, destination)); + AssertExtensions.Throws("destination", () => Negate(x, destination)); + }); } [Fact] - public static void Negate_ThrowsForOverlapppingInputsWithOutputs() + public void 
Negate_ThrowsForOverlapppingInputsWithOutputs() { - float[] array = new float[10]; - AssertExtensions.Throws("destination", () => TensorPrimitives.Negate(array.AsSpan(1, 2), array.AsSpan(0, 2))); - AssertExtensions.Throws("destination", () => TensorPrimitives.Negate(array.AsSpan(1, 2), array.AsSpan(2, 2))); + T[] array = new T[10]; + AssertExtensions.Throws("destination", () => Negate(array.AsSpan(1, 2), array.AsSpan(0, 2))); + AssertExtensions.Throws("destination", () => Negate(array.AsSpan(1, 2), array.AsSpan(2, 2))); } #endregion #region Norm - [Theory] - [InlineData(new float[] { 1, 2, 3 }, 3.7416575f)] - [InlineData(new float[] { 3, 4 }, 5)] - [InlineData(new float[] { 3 }, 3)] - [InlineData(new float[] { 3, 4, 1, 2 }, 5.477226)] - [InlineData(new float[] { }, 0f)] - public static void Norm_KnownValues(float[] x, float expectedResult) - { - AssertEqualTolerance(expectedResult, TensorPrimitives.Norm(x)); - } - - [Theory] - [MemberData(nameof(TensorLengthsIncluding0))] - public static void Norm(int tensorLength) + [Fact] + public void Norm_AllLengths() { - using BoundedMemory x = CreateAndFillTensor(tensorLength); + if (!IsFloatingPoint) return; - float sumOfSquares = 0f; - for (int i = 0; i < x.Length; i++) + Assert.All(Helpers.TensorLengthsIncluding0, tensorLength => { - sumOfSquares += x[i] * x[i]; - } + using BoundedMemory x = CreateAndFillTensor(tensorLength); - AssertEqualTolerance(MathF.Sqrt(sumOfSquares), TensorPrimitives.Norm(x)); + T sumOfSquares = Zero; + for (int i = 0; i < x.Length; i++) + { + sumOfSquares = Add(sumOfSquares, Multiply(x[i], x[i])); + } + + AssertEqualTolerance(Sqrt(sumOfSquares), Norm(x)); + }); } #endregion #region Product [Fact] - public static void Product_ThrowsForEmpty() + public void Product_ThrowsForEmpty() { - Assert.Throws(() => TensorPrimitives.Product(ReadOnlySpan.Empty)); + Assert.Throws(() => Product(ReadOnlySpan.Empty)); } - [Theory] - [MemberData(nameof(TensorLengths))] - public static void Product(int tensorLength) + [Fact] + public void Product_AllLengths() { - using BoundedMemory x = CreateAndFillTensor(tensorLength); - - float f = x[0]; - for (int i = 1; i < x.Length; i++) + Assert.All(Helpers.TensorLengths, tensorLength => { - f *= x[i]; - } + using BoundedMemory x = CreateAndFillTensor(tensorLength); - AssertEqualTolerance(f, TensorPrimitives.Product(x)); - } + T f = x[0]; + for (int i = 1; i < x.Length; i++) + { + f = Multiply(f, x[i]); + } - [Theory] - [InlineData(1, new float[] { 1 })] - [InlineData(-2, new float[] { 1, -2 })] - [InlineData(-6, new float[] { 1, -2, 3 })] - [InlineData(24, new float[] { 1, -2, 3, -4 })] - [InlineData(120, new float[] { 1, -2, 3, -4, 5 })] - [InlineData(-720, new float[] { 1, -2, 3, -4, 5, -6 })] - [InlineData(0, new float[] { 1, -2, 3, -4, 5, -6, 0 })] - [InlineData(0, new float[] { 0, 1, -2, 3, -4, 5, -6 })] - [InlineData(0, new float[] { 1, -2, 3, 0, -4, 5, -6 })] - [InlineData(float.NaN, new float[] { 1, -2, 3, float.NaN, -4, 5, -6 })] - public static void Product_KnownValues(float expected, float[] input) - { - Assert.Equal(expected, TensorPrimitives.Product(input)); + AssertEqualTolerance(f, Product(x)); + }); } #endregion #region ProductOfDifferences [Fact] - public static void ProductOfDifferences_ThrowsForEmptyAndMismatchedLengths() + public void ProductOfDifferences_ThrowsForEmptyAndMismatchedLengths() { - Assert.Throws(() => TensorPrimitives.ProductOfDifferences(ReadOnlySpan.Empty, ReadOnlySpan.Empty)); - Assert.Throws(() => TensorPrimitives.ProductOfDifferences(ReadOnlySpan.Empty, 
CreateTensor(1))); - Assert.Throws(() => TensorPrimitives.ProductOfDifferences(CreateTensor(1), ReadOnlySpan.Empty)); - Assert.Throws(() => TensorPrimitives.ProductOfDifferences(CreateTensor(44), CreateTensor(43))); - Assert.Throws(() => TensorPrimitives.ProductOfDifferences(CreateTensor(43), CreateTensor(44))); + Assert.Throws(() => ProductOfDifferences(ReadOnlySpan.Empty, ReadOnlySpan.Empty)); + Assert.Throws(() => ProductOfDifferences(ReadOnlySpan.Empty, CreateTensor(1))); + Assert.Throws(() => ProductOfDifferences(CreateTensor(1), ReadOnlySpan.Empty)); + Assert.Throws(() => ProductOfDifferences(CreateTensor(44), CreateTensor(43))); + Assert.Throws(() => ProductOfDifferences(CreateTensor(43), CreateTensor(44))); } - [Theory] - [MemberData(nameof(TensorLengths))] - public static void ProductOfDifferences(int tensorLength) + [Fact] + public void ProductOfDifferences_AllLengths() { - using BoundedMemory x = CreateAndFillTensor(tensorLength); - using BoundedMemory y = CreateAndFillTensor(tensorLength); - - float f = x[0] - y[0]; - for (int i = 1; i < x.Length; i++) + Assert.All(Helpers.TensorLengths, tensorLength => { - f *= x[i] - y[i]; - } - AssertEqualTolerance(f, TensorPrimitives.ProductOfDifferences(x, y)); - } - - [Theory] - [InlineData(0, new float[] {0 }, new float[] {0})] - [InlineData(0, new float[] {1 }, new float[] {1})] - [InlineData(1, new float[] {1 }, new float[] {0})] - [InlineData(-1, new float[] {0 }, new float[] {1})] - [InlineData(-1, new float[] {1, 2, 3, 4, 5 }, new float[] {2, 3, 4, 5, 6})] - [InlineData(120, new float[] {1, 2, 3, 4, 5 }, new float[] {0, 0, 0, 0, 0})] - [InlineData(-120, new float[] {0, 0, 0, 0, 0 }, new float[] {1, 2, 3, 4, 5})] - [InlineData(float.NaN, new float[] {1, 2, float.NaN, 4, 5 }, new float[] {0, 0, 0, 0, 0})] - public static void ProductOfDifferences_KnownValues(float expected, float[] x, float[] y) - { - Assert.Equal(expected, TensorPrimitives.ProductOfDifferences(x, y)); + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory y = CreateAndFillTensor(tensorLength); + T f = Subtract(x[0], y[0]); + for (int i = 1; i < x.Length; i++) + { + f = Multiply(f, Subtract(x[i], y[i])); + } + AssertEqualTolerance(f, ProductOfDifferences(x, y)); + }); } #endregion #region ProductOfSums [Fact] - public static void ProductOfSums_ThrowsForEmptyAndMismatchedLengths() + public void ProductOfSums_ThrowsForEmptyAndMismatchedLengths() { - Assert.Throws(() => TensorPrimitives.ProductOfSums(ReadOnlySpan.Empty, ReadOnlySpan.Empty)); - Assert.Throws(() => TensorPrimitives.ProductOfSums(ReadOnlySpan.Empty, CreateTensor(1))); - Assert.Throws(() => TensorPrimitives.ProductOfSums(CreateTensor(1), ReadOnlySpan.Empty)); - Assert.Throws(() => TensorPrimitives.ProductOfSums(CreateTensor(44), CreateTensor(43))); - Assert.Throws(() => TensorPrimitives.ProductOfSums(CreateTensor(43), CreateTensor(44))); + Assert.Throws(() => ProductOfSums(ReadOnlySpan.Empty, ReadOnlySpan.Empty)); + Assert.Throws(() => ProductOfSums(ReadOnlySpan.Empty, CreateTensor(1))); + Assert.Throws(() => ProductOfSums(CreateTensor(1), ReadOnlySpan.Empty)); + Assert.Throws(() => ProductOfSums(CreateTensor(44), CreateTensor(43))); + Assert.Throws(() => ProductOfSums(CreateTensor(43), CreateTensor(44))); } - [Theory] - [MemberData(nameof(TensorLengths))] - public static void ProductOfSums(int tensorLength) + [Fact] + public void ProductOfSums_AllLengths() { - using BoundedMemory x = CreateAndFillTensor(tensorLength); - using BoundedMemory y = CreateAndFillTensor(tensorLength); - 
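
The expected-value computations in this region (ProductOfDifferences, ProductOfSums, and the rest) no longer use float operators; they go through scalar helpers such as Add, Subtract, Multiply, Divide, Zero, and One supplied by the generic test base class. That class is not shown in this hunk, but under .NET's generic-math interfaces a minimal version could look like the sketch below; the member names mirror the call sites above, while the constraint choice and the CreateChecked-based ConvertFromSingle are assumptions:

    using System.Numerics;

    public abstract class GenericTensorTestsBase<T> where T : INumber<T>
    {
        // Scalar oracles used to compute expected values element by element.
        protected static T Zero => T.Zero;
        protected static T One => T.One;
        protected static T NegativeOne => -T.One;

        protected static T Add(T x, T y) => x + y;
        protected static T Subtract(T x, T y) => x - y;
        protected static T Multiply(T x, T y) => x * y;
        protected static T Divide(T x, T y) => x / y;

        // Bridges the old float literals into the generic tests,
        // e.g. ConvertFromSingle(float.NaN) in the NaN-propagation cases.
        protected static T ConvertFromSingle(float value) => T.CreateChecked(value);
    }

Exp, Sinh, Sqrt and friends cannot hang off an INumber<T> constraint; they would belong to a floating-point-only layer constrained to IExponentialFunctions<T>, IHyperbolicFunctions<T>, and IRootFunctions<T>, which is consistent with the IsFloatingPoint early-outs visible throughout this diff.
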
- float f = x[0] + y[0]; - for (int i = 1; i < x.Length; i++) + Assert.All(Helpers.TensorLengths, tensorLength => { - f *= x[i] + y[i]; - } - AssertEqualTolerance(f, TensorPrimitives.ProductOfSums(x, y)); - } + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory y = CreateAndFillTensor(tensorLength); - [Theory] - [InlineData(0, new float[] {0 }, new float[] { 0 })] - [InlineData(1, new float[] {0 }, new float[] { 1 })] - [InlineData(1, new float[] {1 }, new float[] { 0 })] - [InlineData(2, new float[] {1 }, new float[] { 1 })] - [InlineData(10395, new float[] {1, 2, 3, 4, 5 }, new float[] { 2, 3, 4, 5, 6 })] - [InlineData(120, new float[] {1, 2, 3, 4, 5 }, new float[] { 0, 0, 0, 0, 0 })] - [InlineData(120, new float[] {0, 0, 0, 0, 0 }, new float[] { 1, 2, 3, 4, 5 })] - [InlineData(float.NaN, new float[] {1, 2, float.NaN, 4, 5 }, new float[] { 0, 0, 0, 0, 0 })] - public static void ProductOfSums_KnownValues(float expected, float[] x, float[] y) - { - Assert.Equal(expected, TensorPrimitives.ProductOfSums(x, y)); + T f = Add(x[0], y[0]); + for (int i = 1; i < x.Length; i++) + { + f = Multiply(f, Add(x[i], y[i])); + } + AssertEqualTolerance(f, ProductOfSums(x, y)); + }); } #endregion #region Sigmoid - [Theory] - [MemberData(nameof(TensorLengths))] - public static void Sigmoid(int tensorLength) + [Fact] + public void Sigmoid_AllLengths() { - using BoundedMemory x = CreateAndFillTensor(tensorLength); - using BoundedMemory destination = CreateTensor(tensorLength); - - TensorPrimitives.Sigmoid(x, destination); + if (!IsFloatingPoint) return; - for (int i = 0; i < tensorLength; i++) + Assert.All(Helpers.TensorLengths, tensorLength => { - AssertEqualTolerance(1f / (1f + MathF.Exp(-x[i])), destination[i]); - } - } - - [Theory] - [MemberData(nameof(TensorLengths))] - public static void Sigmoid_InPlace(int tensorLength) - { - using BoundedMemory x = CreateAndFillTensor(tensorLength); - float[] xOrig = x.Span.ToArray(); + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory destination = CreateTensor(tensorLength); - TensorPrimitives.Sigmoid(x, x); + Sigmoid(x, destination); - for (int i = 0; i < tensorLength; i++) - { - AssertEqualTolerance(1f / (1f + MathF.Exp(-xOrig[i])), x[i]); - } + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(Divide(One, Add(One, Exp(Multiply(x[i], NegativeOne)))), destination[i]); + } + }); } - [Theory] - [MemberData(nameof(TensorLengths))] - public static void Sigmoid_SpecialValues(int tensorLength) + [Fact] + public void Sigmoid_InPlace() { - using BoundedMemory x = CreateAndFillTensor(tensorLength); - using BoundedMemory destination = CreateTensor(tensorLength); + if (!IsFloatingPoint) return; - RunForEachSpecialValue(() => + Assert.All(Helpers.TensorLengths, tensorLength => { - TensorPrimitives.Sigmoid(x, destination); + using BoundedMemory x = CreateAndFillTensor(tensorLength); + T[] xOrig = x.Span.ToArray(); + + Sigmoid(x, x); + for (int i = 0; i < tensorLength; i++) { - AssertEqualTolerance(1f / (1f + MathF.Exp(-x[i])), destination[i]); + AssertEqualTolerance(Divide(One, Add(One, Exp(Multiply(xOrig[i], NegativeOne)))), x[i]); } - }, x); + }); } - [Theory] - [InlineData(new float[] { -5, -4.5f, -4 }, new float[] { 0.0066f, 0.0109f, 0.0179f })] - [InlineData(new float[] { 4.5f, 5 }, new float[] { 0.9890f, 0.9933f })] - [InlineData(new float[] { 0, -3, 3, .5f }, new float[] { 0.5f, 0.0474f, 0.9525f, 0.6224f })] - public static void Sigmoid_KnownValues(float[] x, float[] expectedResult) + [Fact] + 
public void Sigmoid_SpecialValues() { - using BoundedMemory dest = CreateTensor(x.Length); - TensorPrimitives.Sigmoid(x, dest); + if (!IsFloatingPoint) return; - for (int i = 0; i < x.Length; i++) + Assert.All(Helpers.TensorLengths, tensorLength => { - AssertEqualTolerance(expectedResult[i], dest[i], 0.0001f); - } + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory destination = CreateTensor(tensorLength); + + RunForEachSpecialValue(() => + { + Sigmoid(x, destination); + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(Divide(One, Add(One, Exp(Multiply(x[i], NegativeOne)))), destination[i]); + } + }, x); + }); } - [Theory] - [InlineData(new float[] { -5, -4.5f, -4 }, new float[] { 0.0066f, 0.0109f, 0.0179f })] - public static void Sigmoid_DestinationLongerThanSource(float[] x, float[] expectedResult) + [Fact] + public void Sigmoid_ThrowsForTooShortDestination() { - using BoundedMemory dest = CreateTensor(x.Length + 1); - - TensorPrimitives.Sigmoid(x, dest); + if (!IsFloatingPoint) return; - float originalLast = dest[dest.Length - 1]; - for (int i = 0; i < x.Length; i++) + Assert.All(Helpers.TensorLengths, tensorLength => { - AssertEqualTolerance(expectedResult[i], dest[i], 0.0001f); - } - Assert.Equal(originalLast, dest[dest.Length - 1]); - } + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory destination = CreateTensor(tensorLength - 1); - [Theory] - [MemberData(nameof(TensorLengths))] - public static void Sigmoid_ThrowsForTooShortDestination(int tensorLength) - { - using BoundedMemory x = CreateAndFillTensor(tensorLength); - using BoundedMemory destination = CreateTensor(tensorLength - 1); - - AssertExtensions.Throws("destination", () => TensorPrimitives.Sigmoid(x, destination)); + AssertExtensions.Throws("destination", () => Sigmoid(x, destination)); + }); } [Fact] - public static void Sigmoid_ThrowsForEmptyInput() + public void Sigmoid_ThrowsForEmptyInput() { - AssertExtensions.Throws(() => TensorPrimitives.Sigmoid(ReadOnlySpan.Empty, CreateTensor(1))); + if (!IsFloatingPoint) return; + + AssertExtensions.Throws(() => Sigmoid(ReadOnlySpan.Empty, CreateTensor(1))); } [Fact] - public static void Sigmoid_ThrowsForOverlapppingInputsWithOutputs() + public void Sigmoid_ThrowsForOverlapppingInputsWithOutputs() { - float[] array = new float[10]; - AssertExtensions.Throws("destination", () => TensorPrimitives.Sigmoid(array.AsSpan(1, 2), array.AsSpan(0, 2))); - AssertExtensions.Throws("destination", () => TensorPrimitives.Sigmoid(array.AsSpan(1, 2), array.AsSpan(2, 2))); + if (!IsFloatingPoint) return; + + T[] array = new T[10]; + AssertExtensions.Throws("destination", () => Sigmoid(array.AsSpan(1, 2), array.AsSpan(0, 2))); + AssertExtensions.Throws("destination", () => Sigmoid(array.AsSpan(1, 2), array.AsSpan(2, 2))); } #endregion #region Sinh - [Theory] - [MemberData(nameof(TensorLengthsIncluding0))] - public static void Sinh(int tensorLength) + [Fact] + public void Sinh_AllLengths() { - using BoundedMemory x = CreateAndFillTensor(tensorLength); - using BoundedMemory destination = CreateTensor(tensorLength); - - TensorPrimitives.Sinh(x, destination); + if (!IsFloatingPoint) return; - for (int i = 0; i < tensorLength; i++) + Assert.All(Helpers.TensorLengthsIncluding0, tensorLength => { - AssertEqualTolerance(MathF.Sinh(x[i]), destination[i]); - } - } + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory destination = CreateTensor(tensorLength); - [Theory] - 
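
Worth noting as the Sinh hunks continue: every transcendental test in this diff now opens with `if (!IsFloatingPoint) return;`, so integer instantiations of the generic test class become no-ops for operations that are only defined for IEEE types, rather than asserting on meaningless results. A plausible sketch of that guard (hypothetical; the PR may define it per derived class instead of via type checks):

    // Trig/exp/log oracles only make sense for IEEE floating-point element types,
    // so integer instantiations of the test class exit these tests early.
    private static bool IsFloatingPoint =>
        typeof(T) == typeof(Half) || typeof(T) == typeof(float) || typeof(T) == typeof(double);
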
[MemberData(nameof(TensorLengthsIncluding0))] - public static void Sinh_InPlace(int tensorLength) - { - using BoundedMemory x = CreateAndFillTensor(tensorLength); - float[] xOrig = x.Span.ToArray(); - - TensorPrimitives.Sinh(x, x); + Sinh(x, destination); - for (int i = 0; i < tensorLength; i++) - { - AssertEqualTolerance(MathF.Sinh(xOrig[i]), x[i]); - } + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(Sinh(x[i]), destination[i]); + } + }); } - [Theory] - [MemberData(nameof(TensorLengths))] - public static void Sinh_SpecialValues(int tensorLength) + [Fact] + public void Sinh_InPlace() { - using BoundedMemory x = CreateAndFillTensor(tensorLength); - using BoundedMemory destination = CreateTensor(tensorLength); + if (!IsFloatingPoint) return; - RunForEachSpecialValue(() => + Assert.All(Helpers.TensorLengthsIncluding0, tensorLength => { - TensorPrimitives.Sinh(x, destination); + using BoundedMemory x = CreateAndFillTensor(tensorLength); + T[] xOrig = x.Span.ToArray(); + + Sinh(x, x); + for (int i = 0; i < tensorLength; i++) { - AssertEqualTolerance(MathF.Sinh(x[i]), destination[i]); + AssertEqualTolerance(Sinh(xOrig[i]), x[i]); } - }, x); + }); } - [Theory] - [MemberData(nameof(VectorLengthAndIteratedRange), new object[] { -100f, 100f, 3f })] - public static void Sinh_ValueRange(int vectorLengths, float element) + [Fact] + public void Sinh_SpecialValues() { - float[] x = new float[vectorLengths]; - float[] dest = new float[vectorLengths]; + if (!IsFloatingPoint) return; + + Assert.All(Helpers.TensorLengths, tensorLength => + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory destination = CreateTensor(tensorLength); + + RunForEachSpecialValue(() => + { + Sinh(x, destination); + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(Sinh(x[i]), destination[i]); + } + }, x); + }); + } - x.AsSpan().Fill(element); - TensorPrimitives.Sinh(x, dest); + [Fact] + public void Sinh_ValueRange() + { + if (!IsFloatingPoint) return; - float expected = MathF.Sinh(element); - foreach (float actual in dest) + Assert.All(VectorLengthAndIteratedRange(ConvertFromSingle(-100f), ConvertFromSingle(100f), ConvertFromSingle(3f)), args => { - AssertEqualTolerance(expected, actual); - } + T[] x = new T[args.Length]; + T[] dest = new T[args.Length]; + + x.AsSpan().Fill(args.Element); + Sinh(x, dest); + + T expected = Sinh(args.Element); + foreach (T actual in dest) + { + AssertEqualTolerance(expected, actual); + } + }); } - [Theory] - [MemberData(nameof(TensorLengths))] - public static void Sinh_ThrowsForTooShortDestination(int tensorLength) + [Fact] + public void Sinh_ThrowsForTooShortDestination() { - using BoundedMemory x = CreateAndFillTensor(tensorLength); - using BoundedMemory destination = CreateTensor(tensorLength - 1); + if (!IsFloatingPoint) return; - AssertExtensions.Throws("destination", () => TensorPrimitives.Sinh(x, destination)); + Assert.All(Helpers.TensorLengths, tensorLength => + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory destination = CreateTensor(tensorLength - 1); + + AssertExtensions.Throws("destination", () => Sinh(x, destination)); + }); } [Fact] - public static void Sinh_ThrowsForOverlapppingInputsWithOutputs() + public void Sinh_ThrowsForOverlapppingInputsWithOutputs() { - float[] array = new float[10]; - AssertExtensions.Throws("destination", () => TensorPrimitives.Sinh(array.AsSpan(1, 2), array.AsSpan(0, 2))); - AssertExtensions.Throws("destination", () => 
TensorPrimitives.Sinh(array.AsSpan(1, 2), array.AsSpan(2, 2))); + if (!IsFloatingPoint) return; + + T[] array = new T[10]; + AssertExtensions.Throws("destination", () => Sinh(array.AsSpan(1, 2), array.AsSpan(0, 2))); + AssertExtensions.Throws("destination", () => Sinh(array.AsSpan(1, 2), array.AsSpan(2, 2))); } #endregion #region SoftMax - [Theory] - [MemberData(nameof(TensorLengths))] - public static void SoftMax(int tensorLength) + [Fact] + public void SoftMax_AllLengths() { - using BoundedMemory x = CreateAndFillTensor(tensorLength); - using BoundedMemory destination = CreateTensor(tensorLength); + if (!IsFloatingPoint) return; - TensorPrimitives.SoftMax(x, destination); - - float expSum = MemoryMarshal.ToEnumerable(x.Memory).Sum(MathF.Exp); - for (int i = 0; i < tensorLength; i++) + Assert.All(Helpers.TensorLengths, tensorLength => { - AssertEqualTolerance(MathF.Exp(x[i]) / expSum, destination[i]); - } - } + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory destination = CreateTensor(tensorLength); - [Theory] - [MemberData(nameof(TensorLengths))] - public static void SoftMax_InPlace(int tensorLength) - { - using BoundedMemory x = CreateAndFillTensor(tensorLength); - float[] xOrig = x.Span.ToArray(); + SoftMax(x, destination); - TensorPrimitives.SoftMax(x, x); + T expSum = Zero; + foreach (T value in x.Memory.Span) + { + expSum = Add(expSum, Exp(value)); + } - float expSum = xOrig.Sum(MathF.Exp); - for (int i = 0; i < tensorLength; i++) - { - AssertEqualTolerance(MathF.Exp(xOrig[i]) / expSum, x[i]); - } + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(Divide(Exp(x[i]), expSum), destination[i]); + } + }); } - [Theory] - [InlineData(new float[] { 3, 1, .2f }, new float[] { 0.8360188f, 0.11314284f, 0.05083836f })] - [InlineData(new float[] { 3, 4, 1 }, new float[] { 0.2594f, 0.705384f, 0.0351f })] - [InlineData(new float[] { 5, 3 }, new float[] { 0.8807f, 0.1192f })] - [InlineData(new float[] { 4, 2, 1, 9 }, new float[] { 0.0066f, 9.04658e-4f, 3.32805e-4f, 0.9920f })] - public static void SoftMax_KnownValues(float[] x, float[] expectedResult) + [Fact] + public void SoftMax_InPlace() { - using BoundedMemory dest = CreateTensor(x.Length); - TensorPrimitives.SoftMax(x, dest); + if (!IsFloatingPoint) return; - for (int i = 0; i < x.Length; i++) + Assert.All(Helpers.TensorLengths, tensorLength => { - AssertEqualTolerance(expectedResult[i], dest[i], 0.0001f); - } + using BoundedMemory x = CreateAndFillTensor(tensorLength); + T[] xOrig = x.Span.ToArray(); + + SoftMax(x, x); + + T expSum = Zero; + foreach (T value in xOrig) + { + expSum = Add(expSum, Exp(value)); + } + + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(Divide(Exp(xOrig[i]), expSum), x[i]); + } + }); } [Fact] - public static void SoftMax_DestinationLongerThanSource() + public void SoftMax_ThrowsForTooShortDestination() { - float[] x = [3, 1, .2f]; - float[] expectedResult = [0.8360188f, 0.11314284f, 0.05083836f]; - using BoundedMemory dest = CreateTensor(x.Length + 1); - TensorPrimitives.SoftMax(x, dest); + if (!IsFloatingPoint) return; - for (int i = 0; i < x.Length; i++) + Assert.All(Helpers.TensorLengths, tensorLength => { - AssertEqualTolerance(expectedResult[i], dest[i]); - } - } - - [Theory] - [MemberData(nameof(TensorLengths))] - public static void SoftMax_ThrowsForTooShortDestination(int tensorLength) - { - using BoundedMemory x = CreateAndFillTensor(tensorLength); - using BoundedMemory destination = CreateTensor(tensorLength - 1); + using BoundedMemory x = 
CreateAndFillTensor(tensorLength); + using BoundedMemory destination = CreateTensor(tensorLength - 1); - AssertExtensions.Throws("destination", () => TensorPrimitives.SoftMax(x, destination)); + AssertExtensions.Throws("destination", () => SoftMax(x, destination)); + }); } [Fact] - public static void SoftMax_ThrowsForEmptyInput() + public void SoftMax_ThrowsForEmptyInput() { - AssertExtensions.Throws(() => TensorPrimitives.SoftMax(ReadOnlySpan.Empty, CreateTensor(1))); + if (!IsFloatingPoint) return; + + AssertExtensions.Throws(() => SoftMax(ReadOnlySpan.Empty, CreateTensor(1))); } [Fact] - public static void SoftMax_ThrowsForOverlapppingInputsWithOutputs() + public void SoftMax_ThrowsForOverlapppingInputsWithOutputs() { - float[] array = new float[10]; - AssertExtensions.Throws("destination", () => TensorPrimitives.SoftMax(array.AsSpan(1, 2), array.AsSpan(0, 2))); - AssertExtensions.Throws("destination", () => TensorPrimitives.SoftMax(array.AsSpan(1, 2), array.AsSpan(2, 2))); + if (!IsFloatingPoint) return; + + T[] array = new T[10]; + AssertExtensions.Throws("destination", () => SoftMax(array.AsSpan(1, 2), array.AsSpan(0, 2))); + AssertExtensions.Throws("destination", () => SoftMax(array.AsSpan(1, 2), array.AsSpan(2, 2))); } #endregion #region Subtract - [Theory] - [MemberData(nameof(TensorLengthsIncluding0))] - public static void Subtract_TwoTensors(int tensorLength) + [Fact] + public void Subtract_TwoTensors() { - using BoundedMemory x = CreateAndFillTensor(tensorLength); - using BoundedMemory y = CreateAndFillTensor(tensorLength); - using BoundedMemory destination = CreateTensor(tensorLength); + Assert.All(Helpers.TensorLengthsIncluding0, tensorLength => + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory y = CreateAndFillTensor(tensorLength); + using BoundedMemory destination = CreateTensor(tensorLength); - TensorPrimitives.Subtract(x, y, destination); + Subtract(x, y, destination); - for (int i = 0; i < tensorLength; i++) - { - AssertEqualTolerance(x[i] - y[i], destination[i]); - } + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(Subtract(x[i], y[i]), destination[i]); + } + }); } - [Theory] - [MemberData(nameof(TensorLengthsIncluding0))] - public static void Subtract_TwoTensors_InPlace(int tensorLength) + [Fact] + public void Subtract_TwoTensors_InPlace() { - using BoundedMemory x = CreateAndFillTensor(tensorLength); - float[] xOrig = x.Span.ToArray(); + Assert.All(Helpers.TensorLengthsIncluding0, tensorLength => + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + T[] xOrig = x.Span.ToArray(); - TensorPrimitives.Subtract(x, x, x); + Subtract(x, x, x); - for (int i = 0; i < tensorLength; i++) - { - AssertEqualTolerance(xOrig[i] - xOrig[i], x[i]); - } + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(Zero, x[i]); + } + }); } - [Theory] - [MemberData(nameof(TensorLengths))] - public static void Subtract_TwoTensors_ThrowsForMismatchedLengths(int tensorLength) + [Fact] + public void Subtract_TwoTensors_ThrowsForMismatchedLengths() { - using BoundedMemory x = CreateAndFillTensor(tensorLength); - using BoundedMemory y = CreateAndFillTensor(tensorLength - 1); - using BoundedMemory destination = CreateTensor(tensorLength); + Assert.All(Helpers.TensorLengths, tensorLength => + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory y = CreateAndFillTensor(tensorLength - 1); + using BoundedMemory destination = CreateTensor(tensorLength); - Assert.Throws(() => 

-        [Theory]
-        [MemberData(nameof(TensorLengths))]
-        public static void Subtract_TwoTensors_ThrowsForMismatchedLengths(int tensorLength)
+        [Fact]
+        public void Subtract_TwoTensors_ThrowsForMismatchedLengths()
         {
-            using BoundedMemory<float> x = CreateAndFillTensor(tensorLength);
-            using BoundedMemory<float> y = CreateAndFillTensor(tensorLength - 1);
-            using BoundedMemory<float> destination = CreateTensor(tensorLength);
+            Assert.All(Helpers.TensorLengths, tensorLength =>
+            {
+                using BoundedMemory<T> x = CreateAndFillTensor(tensorLength);
+                using BoundedMemory<T> y = CreateAndFillTensor(tensorLength - 1);
+                using BoundedMemory<T> destination = CreateTensor(tensorLength);

-            Assert.Throws<ArgumentException>(() => TensorPrimitives.Subtract(x, y, destination));
-            Assert.Throws<ArgumentException>(() => TensorPrimitives.Subtract(y, x, destination));
+                Assert.Throws<ArgumentException>(() => Subtract(x, y, destination));
+                Assert.Throws<ArgumentException>(() => Subtract(y, x, destination));
+            });
         }

-        [Theory]
-        [MemberData(nameof(TensorLengths))]
-        public static void Subtract_TwoTensors_ThrowsForTooShortDestination(int tensorLength)
+        [Fact]
+        public void Subtract_TwoTensors_ThrowsForTooShortDestination()
         {
-            using BoundedMemory<float> x = CreateAndFillTensor(tensorLength);
-            using BoundedMemory<float> y = CreateAndFillTensor(tensorLength);
-            using BoundedMemory<float> destination = CreateTensor(tensorLength - 1);
+            Assert.All(Helpers.TensorLengths, tensorLength =>
+            {
+                using BoundedMemory<T> x = CreateAndFillTensor(tensorLength);
+                using BoundedMemory<T> y = CreateAndFillTensor(tensorLength);
+                using BoundedMemory<T> destination = CreateTensor(tensorLength - 1);

-            AssertExtensions.Throws<ArgumentException>("destination", () => TensorPrimitives.Subtract(x, y, destination));
+                AssertExtensions.Throws<ArgumentException>("destination", () => Subtract(x, y, destination));
+            });
         }

         [Fact]
-        public static void Subtract_TwoTensors_ThrowsForOverlapppingInputsWithOutputs()
+        public void Subtract_TwoTensors_ThrowsForOverlapppingInputsWithOutputs()
         {
-            float[] array = new float[10];
-            AssertExtensions.Throws<ArgumentException>("destination", () => TensorPrimitives.Subtract(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(0, 2)));
-            AssertExtensions.Throws<ArgumentException>("destination", () => TensorPrimitives.Subtract(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(2, 2)));
-            AssertExtensions.Throws<ArgumentException>("destination", () => TensorPrimitives.Subtract(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(3, 2)));
-            AssertExtensions.Throws<ArgumentException>("destination", () => TensorPrimitives.Subtract(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(5, 2)));
+            T[] array = new T[10];
+            AssertExtensions.Throws<ArgumentException>("destination", () => Subtract(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(0, 2)));
+            AssertExtensions.Throws<ArgumentException>("destination", () => Subtract(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(2, 2)));
+            AssertExtensions.Throws<ArgumentException>("destination", () => Subtract(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(3, 2)));
+            AssertExtensions.Throws<ArgumentException>("destination", () => Subtract(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(5, 2)));
         }

-        [Theory]
-        [MemberData(nameof(TensorLengthsIncluding0))]
-        public static void Subtract_TensorScalar(int tensorLength)
+        [Fact]
+        public void Subtract_TensorScalar()
         {
-            using BoundedMemory<float> x = CreateAndFillTensor(tensorLength);
-            float y = NextSingle();
-            using BoundedMemory<float> destination = CreateTensor(tensorLength);
+            Assert.All(Helpers.TensorLengthsIncluding0, tensorLength =>
+            {
+                using BoundedMemory<T> x = CreateAndFillTensor(tensorLength);
+                T y = NextRandom();
+                using BoundedMemory<T> destination = CreateTensor(tensorLength);

-            TensorPrimitives.Subtract(x, y, destination);
+                Subtract(x, y, destination);

-            for (int i = 0; i < tensorLength; i++)
-            {
-                AssertEqualTolerance(x[i] - y, destination[i]);
-            }
+                for (int i = 0; i < tensorLength; i++)
+                {
+                    AssertEqualTolerance(Subtract(x[i], y), destination[i]);
+                }
+            });
         }

-        [Theory]
-        [MemberData(nameof(TensorLengthsIncluding0))]
-        public static void Subtract_TensorScalar_InPlace(int tensorLength)
+        [Fact]
+        public void Subtract_TensorScalar_InPlace()
         {
-            using BoundedMemory<float> x = CreateAndFillTensor(tensorLength);
-            float[] xOrig = x.Span.ToArray();
-            float y = NextSingle();
+            Assert.All(Helpers.TensorLengthsIncluding0, tensorLength =>
+            {
+                using BoundedMemory<T> x = CreateAndFillTensor(tensorLength);
+                T[] xOrig = x.Span.ToArray();
+                T y = NextRandom();

-            TensorPrimitives.Subtract(x, y, x);
+                Subtract(x, y, x);

-            for (int i = 0; i < tensorLength; i++)
-            {
-                AssertEqualTolerance(xOrig[i] - y, x[i]);
-            }
+                for (int i = 0; i < tensorLength; i++)
+                {
+                    AssertEqualTolerance(Subtract(xOrig[i], y), x[i]);
+                }
+            });
         }

-        [Theory]
-        [MemberData(nameof(TensorLengths))]
-        public static void Subtract_TensorScalar_ThrowsForTooShortDestination(int tensorLength)
+        [Fact]
+        public void Subtract_TensorScalar_ThrowsForTooShortDestination()
         {
-            using BoundedMemory<float> x = CreateAndFillTensor(tensorLength);
-            float y = NextSingle();
-            using BoundedMemory<float> destination = CreateTensor(tensorLength - 1);
+            Assert.All(Helpers.TensorLengths, tensorLength =>
+            {
+                using BoundedMemory<T> x = CreateAndFillTensor(tensorLength);
+                T y = NextRandom();
+                using BoundedMemory<T> destination = CreateTensor(tensorLength - 1);

-            AssertExtensions.Throws<ArgumentException>("destination", () => TensorPrimitives.Subtract(x, y, destination));
+                AssertExtensions.Throws<ArgumentException>("destination", () => Subtract(x, y, destination));
+            });
         }

         [Fact]
-        public static void Subtract_TensorScalar_ThrowsForOverlapppingInputsWithOutputs()
+        public void Subtract_TensorScalar_ThrowsForOverlapppingInputsWithOutputs()
         {
-            float[] array = new float[10];
-            AssertExtensions.Throws<ArgumentException>("destination", () => TensorPrimitives.Subtract(array.AsSpan(1, 2), 42, array.AsSpan(0, 2)));
-            AssertExtensions.Throws<ArgumentException>("destination", () => TensorPrimitives.Subtract(array.AsSpan(1, 2), 42, array.AsSpan(2, 2)));
+            T[] array = new T[10];
+            AssertExtensions.Throws<ArgumentException>("destination", () => Subtract(array.AsSpan(1, 2), default(T), array.AsSpan(0, 2)));
+            AssertExtensions.Throws<ArgumentException>("destination", () => Subtract(array.AsSpan(1, 2), default(T), array.AsSpan(2, 2)));
         }
         #endregion

         #region Sum
-        [Theory]
-        [MemberData(nameof(TensorLengths))]
-        public static void Sum(int tensorLength)
+        [Fact]
+        public void Sum_AllLengths()
         {
-            using BoundedMemory<float> x = CreateAndFillTensor(tensorLength);
-
-            AssertEqualTolerance(MemoryMarshal.ToEnumerable<float>(x.Memory).Sum(), TensorPrimitives.Sum(x));
-
-            float sum = 0;
-            foreach (float f in x.Span)
+            Assert.All(Helpers.TensorLengths, tensorLength =>
             {
-                sum += f;
-            }
-            AssertEqualTolerance(sum, TensorPrimitives.Sum(x));
-        }
+                using BoundedMemory<T> x = CreateAndFillTensor(tensorLength);

-        [Theory]
-        [InlineData(0, new float[] { 0 })]
-        [InlineData(1, new float[] { 0, 1 })]
-        [InlineData(6, new float[] { 1, 2, 3 })]
-        [InlineData(0, new float[] { -3, 0, 3 })]
-        [InlineData(float.NaN, new float[] { -3, float.NaN, 3 })]
-        public static void Sum_KnownValues(float expected, float[] x)
-        {
-            Assert.Equal(expected, TensorPrimitives.Sum(x));
+                T sum = Zero;
+                foreach (T value in x.Memory.Span)
+                {
+                    sum = Add(sum, value);
+                }
+                AssertEqualTolerance(sum, Sum(x));
+            });
         }
         #endregion

         #region SumOfMagnitudes
-        [Theory]
-        [MemberData(nameof(TensorLengths))]
-        public static void SumOfMagnitudes(int tensorLength)
+        [Fact]
+        public void SumOfMagnitudes_AllLengths()
         {
-            using BoundedMemory<float> x = CreateAndFillTensor(tensorLength);
-
-            AssertEqualTolerance(Enumerable.Sum(MemoryMarshal.ToEnumerable<float>(x.Memory), MathF.Abs), TensorPrimitives.SumOfMagnitudes(x));
-
-            float sum = 0;
-            foreach (float f in x.Span)
+            Assert.All(Helpers.TensorLengths, tensorLength =>
             {
-                sum += MathF.Abs(f);
-            }
-            AssertEqualTolerance(sum, TensorPrimitives.SumOfMagnitudes(x));
-        }
+                using BoundedMemory<T> x = CreateTensor(tensorLength);
+                FillTensor(x, MinValue);

-        [Theory]
-        [InlineData(0, new float[] { 0 })]
-        [InlineData(1, new float[] { 0, 1 })]
-        [InlineData(6, new float[] { 1, 2, 3 })]
-        [InlineData(6, new float[] { -3, 0, 3 })]
-        [InlineData(float.NaN, new float[] { -3, float.NaN, 3 })]
-        public static void SumOfMagnitudes_KnownValues(float expected, float[] x)
-        {
-            Assert.Equal(expected, TensorPrimitives.SumOfMagnitudes(x));
+                T sum = Zero;
+                foreach (T value in x.Memory.Span)
+                {
+                    sum = Add(sum, Abs(value));
+                }
+                AssertEqualTolerance(sum, SumOfMagnitudes(x));
+            });
         }
         #endregion

         #region SumOfSquares
-        [Theory]
-        [MemberData(nameof(TensorLengths))]
-        public static void SumOfSquares(int tensorLength)
+        [Fact]
+        public void SumOfSquares_AllLengths()
         {
-            using BoundedMemory<float> x = CreateAndFillTensor(tensorLength);
-
-            AssertEqualTolerance(Enumerable.Sum(MemoryMarshal.ToEnumerable<float>(x.Memory), v => v * v), TensorPrimitives.SumOfSquares(x));
-
-            float sum = 0;
-            foreach (float f in x.Span)
+            Assert.All(Helpers.TensorLengths, tensorLength =>
             {
-                sum += f * f;
-            }
-            AssertEqualTolerance(sum, TensorPrimitives.SumOfSquares(x));
-        }
+                using BoundedMemory<T> x = CreateAndFillTensor(tensorLength);

-        [Theory]
-        [InlineData(0, new float[] { 0 })]
-        [InlineData(1, new float[] { 0, 1 })]
-        [InlineData(14, new float[] { 1, 2, 3 })]
-        [InlineData(18, new float[] { -3, 0, 3 })]
-        [InlineData(float.NaN, new float[] { -3, float.NaN, 3 })]
-        public static void SumOfSquares_KnownValues(float expected, float[] x)
-        {
-            Assert.Equal(expected, TensorPrimitives.SumOfSquares(x));
+                T sum = Zero;
+                foreach (T value in x.Memory.Span)
+                {
+                    sum = Add(sum, Multiply(value, value));
+                }
+                AssertEqualTolerance(sum, SumOfSquares(x));
+            });
         }
         #endregion

         #region Tanh
-        [Theory]
-        [MemberData(nameof(TensorLengthsIncluding0))]
-        public static void Tanh(int tensorLength)
+        [Fact]
+        public void Tanh_AllLengths()
         {
-            using BoundedMemory<float> x = CreateAndFillTensor(tensorLength);
-            using BoundedMemory<float> destination = CreateTensor(tensorLength);
+            if (!IsFloatingPoint) return;

-            TensorPrimitives.Tanh(x, destination);
-
-            for (int i = 0; i < tensorLength; i++)
+            Assert.All(Helpers.TensorLengthsIncluding0, tensorLength =>
             {
-                AssertEqualTolerance(MathF.Tanh(x[i]), destination[i]);
-            }
-        }
-
-        [Theory]
-        [MemberData(nameof(TensorLengthsIncluding0))]
-        public static void Tanh_InPlace(int tensorLength)
-        {
-            using BoundedMemory<float> x = CreateAndFillTensor(tensorLength);
-            float[] xOrig = x.Span.ToArray();
+                using BoundedMemory<T> x = CreateAndFillTensor(tensorLength);
+                using BoundedMemory<T> destination = CreateTensor(tensorLength);

-            TensorPrimitives.Tanh(x, x);
+                Tanh(x, destination);

-            for (int i = 0; i < tensorLength; i++)
-            {
-                AssertEqualTolerance(MathF.Tanh(xOrig[i]), x[i]);
-            }
+                for (int i = 0; i < tensorLength; i++)
+                {
+                    AssertEqualTolerance(Tanh(x[i]), destination[i]);
+                }
+            });
         }

-        [Theory]
-        [MemberData(nameof(TensorLengths))]
-        public static void Tanh_SpecialValues(int tensorLength)
+        [Fact]
+        public void Tanh_InPlace()
         {
-            using BoundedMemory<float> x = CreateAndFillTensor(tensorLength);
-            using BoundedMemory<float> destination = CreateTensor(tensorLength);
+            if (!IsFloatingPoint) return;

-            RunForEachSpecialValue(() =>
+            Assert.All(Helpers.TensorLengthsIncluding0, tensorLength =>
             {
-                TensorPrimitives.Tanh(x, destination);
+                using BoundedMemory<T> x = CreateAndFillTensor(tensorLength);
+                T[] xOrig = x.Span.ToArray();
+
+                Tanh(x, x);
+
                 for (int i = 0; i < tensorLength; i++)
                 {
-                    AssertEqualTolerance(MathF.Tanh(x[i]), destination[i]);
+                    AssertEqualTolerance(Tanh(xOrig[i]), x[i]);
                 }
-            }, x);
+            });
         }

-        [Theory]
-        [MemberData(nameof(VectorLengthAndIteratedRange), new object[] { -11f, 11f, 0.2f })]
-        public static void Tanh_ValueRange(int vectorLengths, float element)
+        [Fact]
+        public void Tanh_SpecialValues()
         {
-            float[] x = new float[vectorLengths];
-            float[] dest = new float[vectorLengths];
+            if (!IsFloatingPoint) return;

-            x.AsSpan().Fill(element);
-            TensorPrimitives.Tanh(x, dest);
+            Assert.All(Helpers.TensorLengths, tensorLength =>
+            {
+                using BoundedMemory<T> x = CreateAndFillTensor(tensorLength);
+                using BoundedMemory<T> destination = CreateTensor(tensorLength);
+
+                RunForEachSpecialValue(() =>
+                {
+                    Tanh(x, destination);
+                    for (int i = 0; i < tensorLength; i++)
+                    {
+                        AssertEqualTolerance(Tanh(x[i]), destination[i]);
+                    }
+                }, x);
+            });
+        }

-            float expected = MathF.Tanh(element);
-            foreach (float actual in dest)
+        [Fact]
+        public void Tanh_ValueRange()
+        {
+            if (!IsFloatingPoint) return;
+
+            Assert.All(VectorLengthAndIteratedRange(ConvertFromSingle(-11f), ConvertFromSingle(11f), ConvertFromSingle(0.2f)), args =>
             {
-                AssertEqualTolerance(expected, actual);
-            }
+                T[] x = new T[args.Length];
+                T[] dest = new T[args.Length];
+
+                x.AsSpan().Fill(args.Element);
+                Tanh(x, dest);
+
+                T expected = Tanh(args.Element);
+                foreach (T actual in dest)
+                {
+                    AssertEqualTolerance(expected, actual);
+                }
+            });
         }

-        [Theory]
-        [MemberData(nameof(TensorLengths))]
-        public static void Tanh_ThrowsForTooShortDestination(int tensorLength)
+        [Fact]
+        public void Tanh_ThrowsForTooShortDestination()
         {
-            using BoundedMemory<float> x = CreateAndFillTensor(tensorLength);
-            using BoundedMemory<float> destination = CreateTensor(tensorLength - 1);
+            if (!IsFloatingPoint) return;
+
+            Assert.All(Helpers.TensorLengths, tensorLength =>
+            {
+                using BoundedMemory<T> x = CreateAndFillTensor(tensorLength);
+                using BoundedMemory<T> destination = CreateTensor(tensorLength - 1);

-            AssertExtensions.Throws<ArgumentException>("destination", () => TensorPrimitives.Tanh(x, destination));
+                AssertExtensions.Throws<ArgumentException>("destination", () => Tanh(x, destination));
+            });
         }

         [Fact]
-        public static void Tanh_ThrowsForOverlapppingInputsWithOutputs()
+        public void Tanh_ThrowsForOverlapppingInputsWithOutputs()
         {
-            float[] array = new float[10];
-            AssertExtensions.Throws<ArgumentException>("destination", () => TensorPrimitives.Tanh(array.AsSpan(1, 2), array.AsSpan(0, 2)));
-            AssertExtensions.Throws<ArgumentException>("destination", () => TensorPrimitives.Tanh(array.AsSpan(1, 2), array.AsSpan(2, 2)));
+            if (!IsFloatingPoint) return;
+
+            T[] array = new T[10];
+            AssertExtensions.Throws<ArgumentException>("destination", () => Tanh(array.AsSpan(1, 2), array.AsSpan(0, 2)));
+            AssertExtensions.Throws<ArgumentException>("destination", () => Tanh(array.AsSpan(1, 2), array.AsSpan(2, 2)));
         }
         #endregion
     }
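The rewrite above follows one mechanical pattern: each per-length [Theory]/[MemberData(nameof(TensorLengths))] test becomes a single [Fact] that sweeps every length with xunit's Assert.All, run once per element type T by the generic test class. A minimal self-contained sketch of that pattern, assuming an illustrative length list in place of the suite's shared Helpers.TensorLengths:

    using System.Linq;
    using System.Numerics.Tensors;
    using Xunit;

    public class TensorLengthSweepSketch
    {
        // Illustrative stand-in for the suite's shared Helpers.TensorLengths.
        private static readonly int[] s_tensorLengths = Enumerable.Range(1, 64).ToArray();

        [Fact]
        public void Sum_MatchesScalarAccumulation()
        {
            Assert.All(s_tensorLengths, tensorLength =>
            {
                // Integer-valued floats keep the comparison exact even if the
                // vectorized implementation reorders the additions.
                float[] x = Enumerable.Range(0, tensorLength).Select(i => (float)i).ToArray();

                float expected = 0;
                foreach (float value in x)
                {
                    expected += value;
                }

                Assert.Equal(expected, TensorPrimitives.Sum<float>(x));
            });
        }
    }

One trade-off of this design: a failing length no longer shows up as its own test case in the runner, but Assert.All reports every failing length at once instead of stopping at the first.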
diff --git a/src/libraries/System.Numerics.Tensors/tests/TensorPrimitivesTests.netcore.cs b/src/libraries/System.Numerics.Tensors/tests/TensorPrimitivesTests.netcore.cs
deleted file mode 100644
index 06ab341db16242..00000000000000
--- a/src/libraries/System.Numerics.Tensors/tests/TensorPrimitivesTests.netcore.cs
+++ /dev/null
@@ -1,142 +0,0 @@
-// Licensed to the .NET Foundation under one or more agreements.
-// The .NET Foundation licenses this file to you under the MIT license.
-
-using System.Buffers;
-using Xunit;
-
-namespace System.Numerics.Tensors.Tests
-{
-    public static partial class TensorPrimitivesTests
-    {
-        #region ConvertToHalf
-        [Theory]
-        [MemberData(nameof(TensorLengthsIncluding0))]
-        public static void ConvertToHalf(int tensorLength)
-        {
-            using BoundedMemory<float> source = CreateAndFillTensor(tensorLength);
-            foreach (int destLength in new[] { source.Length, source.Length + 1 })
-            {
-                using BoundedMemory<Half> destination = BoundedMemory.Allocate<Half>(destLength);
-                destination.Span.Fill(Half.Zero);
-
-                TensorPrimitives.ConvertToHalf(source, destination);
-
-                for (int i = 0; i < source.Length; i++)
-                {
-                    Assert.Equal((Half)source[i], destination[i]);
-                }
-
-                if (destination.Length > source.Length)
-                {
-                    for (int i = source.Length; i < destination.Length; i++)
-                    {
-                        Assert.Equal(Half.Zero, destination[i]);
-                    }
-                }
-            }
-        }
-
-        [Theory]
-        [MemberData(nameof(TensorLengths))]
-        public static void ConvertToHalf_SpecialValues(int tensorLength)
-        {
-            using BoundedMemory<float> source = CreateAndFillTensor(tensorLength);
-            using BoundedMemory<Half> destination = BoundedMemory.Allocate<Half>(tensorLength);
-
-            // NaN, infinities, and 0s
-            source[s_random.Next(source.Length)] = float.NaN;
-            source[s_random.Next(source.Length)] = float.PositiveInfinity;
-            source[s_random.Next(source.Length)] = float.NegativeInfinity;
-            source[s_random.Next(source.Length)] = 0;
-            source[s_random.Next(source.Length)] = float.NegativeZero;
-
-            TensorPrimitives.ConvertToHalf(source, destination);
-
-            for (int i = 0; i < source.Length; i++)
-            {
-                Assert.Equal((Half)source[i], destination[i]);
-            }
-        }
-
-        [Theory]
-        [MemberData(nameof(TensorLengths))]
-        public static void ConvertToHalf_ThrowsForTooShortDestination(int tensorLength)
-        {
-            using BoundedMemory<float> source = CreateAndFillTensor(tensorLength);
-            Half[] destination = new Half[source.Length - 1];
-
-            AssertExtensions.Throws<ArgumentException>("destination", () => TensorPrimitives.ConvertToHalf(source, destination));
-        }
-        #endregion
-
-        #region ConvertToSingle
-        [Theory]
-        [MemberData(nameof(TensorLengthsIncluding0))]
-        public static void ConvertToSingle(int tensorLength)
-        {
-            using BoundedMemory<Half> source = BoundedMemory.Allocate<Half>(tensorLength);
-            for (int i = 0; i < source.Length; i++)
-            {
-                source[i] = (Half)s_random.NextSingle();
-            }
-
-            foreach (int destLength in new[] { source.Length, source.Length + 1 })
-            {
-                using BoundedMemory<float> destination = CreateTensor(destLength);
-                destination.Span.Fill(0f);
-
-                TensorPrimitives.ConvertToSingle(source, destination);
-
-                for (int i = 0; i < source.Length; i++)
-                {
-                    Assert.Equal((float)source[i], destination[i]);
-                }
-
-                if (destination.Length > source.Length)
-                {
-                    for (int i = source.Length; i < destination.Length; i++)
-                    {
-                        Assert.Equal(0f, destination[i]);
-                    }
-                }
-            }
-        }
-
-        [Theory]
-        [MemberData(nameof(TensorLengths))]
-        public static void ConvertToSingle_SpecialValues(int tensorLength)
-        {
-            using BoundedMemory<Half> source = BoundedMemory.Allocate<Half>(tensorLength);
-            for (int i = 0; i < source.Length; i++)
-            {
-                source[i] = (Half)s_random.NextSingle();
-            }
-
-            using BoundedMemory<float> destination = CreateTensor(tensorLength);
-
-            // NaN, infinities, and 0s
-            source[s_random.Next(source.Length)] = Half.NaN;
-            source[s_random.Next(source.Length)] = Half.PositiveInfinity;
-            source[s_random.Next(source.Length)] = Half.NegativeInfinity;
-            source[s_random.Next(source.Length)] = Half.Zero;
-            source[s_random.Next(source.Length)] = Half.NegativeZero;
-
-            TensorPrimitives.ConvertToSingle(source, destination);
-
-            for (int i = 0; i < source.Length; i++)
-            {
-                Assert.Equal((float)source[i], destination[i]);
-            }
-        }
-
-        [Theory]
-        [MemberData(nameof(TensorLengths))]
-        public static void ConvertToSingle_ThrowsForTooShortDestination(int tensorLength)
-        {
-            Half[] source = new Half[tensorLength];
-            using BoundedMemory<float> destination = CreateTensor(source.Length - 1);
-
-            AssertExtensions.Throws<ArgumentException>("destination", () => TensorPrimitives.ConvertToSingle(source, destination));
-        }
-        #endregion
-    }
-}
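For reference, the file deleted above covered TensorPrimitives.ConvertToHalf and TensorPrimitives.ConvertToSingle, the two entry points that stay float/Half-specific rather than becoming generic. A minimal standalone sketch of those two APIs (sample values are illustrative); the round trip is lossy for values that need more precision than Half carries:

    using System;
    using System.Numerics.Tensors;

    class ConvertRoundTripSketch
    {
        static void Main()
        {
            ReadOnlySpan<float> source = stackalloc float[] { 0.1f, 1.5f, -2.25f, float.PositiveInfinity };
            Span<Half> halves = stackalloc Half[source.Length];
            Span<float> roundTripped = stackalloc float[source.Length];

            TensorPrimitives.ConvertToHalf(source, halves);          // narrow float -> Half
            TensorPrimitives.ConvertToSingle(halves, roundTripped);  // widen Half -> float

            for (int i = 0; i < source.Length; i++)
            {
                // 1.5f, -2.25f, and the infinity survive exactly; 0.1f does not.
                Console.WriteLine($"{source[i]} -> {halves[i]} -> {roundTripped[i]}");
            }
        }
    }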