Skip to content
Merged
12 changes: 10 additions & 2 deletions src/libraries/Common/tests/Tests/System/StringTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4697,14 +4697,22 @@ public static void Remove_Invalid()
[InlineData("Aaaaaaaa", 'A', 'a', "aaaaaaaa")] // Single iteration of vectorised path; no remainders through non-vectorised path
// Three leading 'a's before a match (copyLength > 0), Single iteration of vectorised path; no remainders through non-vectorised path
[InlineData("aaaAaaaaaaa", 'A', 'a', "aaaaaaaaaaa")]
// Single iteration of vectorised path; 3 remainders through non-vectorised path
// Single iteration of vectorised path; 3 remainders handled by vectorized path
[InlineData("AaaaaaaaaAa", 'A', 'a', "aaaaaaaaaaa")]
// Single iteration of vectorized path; 0 remainders handled by vectorized path
[InlineData("aaaaaaaaaAa", 'A', 'a', "aaaaaaaaaaa")]
// Eight chars before a match (copyLength > 0), single iteration of vectorized path for the remainder
[InlineData("12345678AAAAAAA", 'A', 'a', "12345678aaaaaaa")]
// ------------------------- For Vector<ushort>.Count == 16 (AVX2) -------------------------
[InlineData("AaaaaaaaAaaaaaaa", 'A', 'a', "aaaaaaaaaaaaaaaa")] // Single iteration of vectorised path; no remainders through non-vectorised path
// Three leading 'a's before a match (copyLength > 0), Single iteration of vectorised path; no remainders through non-vectorised path
[InlineData("aaaAaaaaaaaAaaaaaaa", 'A', 'a', "aaaaaaaaaaaaaaaaaaa")]
// Single iteration of vectorised path; 3 remainders through non-vectorised path
// Single iteration of vectorised path; 3 remainders handled by vectorized path
[InlineData("AaaaaaaaAaaaaaaaaAa", 'A', 'a', "aaaaaaaaaaaaaaaaaaa")]
// Single iteration of vectorized path; 0 remainders handled by vectorized path
[InlineData("aaaaaaaaaaaaaaaaaAa", 'A', 'a', "aaaaaaaaaaaaaaaaaaa")]
// Sixteen chars before a match (copyLength > 0), single iteration of vectorized path for the remainder
[InlineData("1234567890123456AAAAAAAAAAAAAAA", 'A', 'a', "1234567890123456aaaaaaaaaaaaaaa")]
// ----------------------------------- General test data -----------------------------------
[InlineData("Hello", 'l', '!', "He!!o")] // 2 match, non-vectorised path
[InlineData("Hello", 'e', 'e', "Hello")] // oldChar and newChar are same; nothing to replace
Expand Down
16 changes: 16 additions & 0 deletions src/libraries/System.Private.CoreLib/src/System/Numerics/Vector.cs
Original file line number Diff line number Diff line change
Expand Up @@ -895,6 +895,14 @@ public static bool LessThanOrEqualAll<T>(Vector<T> left, Vector<T> right)
public static bool LessThanOrEqualAny<T>(Vector<T> left, Vector<T> right)
where T : struct => LessThanOrEqual(left, right).As<T, nuint>() != Vector<nuint>.Zero;

[MethodImpl(MethodImplOptions.AggressiveInlining)]
internal static Vector<T> LoadUnsafe<T>(ref T source, nuint elementOffset)
where T : struct
{
source = ref Unsafe.Add(ref source, elementOffset);
return Unsafe.ReadUnaligned<Vector<T>>(ref Unsafe.As<T, byte>(ref source));
}

/// <summary>Computes the maximum of two vectors on a per-element basis.</summary>
/// <param name="left">The vector to compare with <paramref name="right" />.</param>
/// <param name="right">The vector to compare with <paramref name="left" />.</param>
Expand Down Expand Up @@ -1658,6 +1666,14 @@ public static Vector<T> SquareRoot<T>(Vector<T> value)
return result;
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
internal static void StoreUnsafe<T>(this Vector<T> source, ref T destination, nuint elementOffset)
where T : struct
{
destination = ref Unsafe.Add(ref destination, elementOffset);
Unsafe.WriteUnaligned(ref Unsafe.As<T, byte>(ref destination), source);
}

/// <summary>Subtracts two vectors to compute their difference.</summary>
/// <param name="left">The vector from which <paramref name="right" /> will be subtracted.</param>
/// <param name="right">The vector to subtract from <paramref name="left" />.</param>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -994,7 +994,7 @@ public string Replace(char oldChar, char newChar)
if (firstIndex < 0)
return this;

int remainingLength = Length - firstIndex;
nuint remainingLength = (uint)(Length - firstIndex);
string result = FastAllocateString(Length);

int copyLength = firstIndex;
Expand All @@ -1006,35 +1006,56 @@ public string Replace(char oldChar, char newChar)
}

// Copy the remaining characters, doing the replacement as we go.
ref ushort pSrc = ref Unsafe.Add(ref Unsafe.As<char, ushort>(ref _firstChar), copyLength);
ref ushort pDst = ref Unsafe.Add(ref Unsafe.As<char, ushort>(ref result._firstChar), copyLength);
ref ushort pSrc = ref Unsafe.Add(ref GetRawStringDataAsUInt16(), (uint)copyLength);
ref ushort pDst = ref Unsafe.Add(ref result.GetRawStringDataAsUInt16(), (uint)copyLength);
nuint i = 0;

if (Vector.IsHardwareAccelerated && remainingLength >= Vector<ushort>.Count)
if (Vector.IsHardwareAccelerated && Length >= Vector<ushort>.Count)
{
Vector<ushort> oldChars = new Vector<ushort>(oldChar);
Vector<ushort> newChars = new Vector<ushort>(newChar);
Vector<ushort> oldChars = new(oldChar);
Vector<ushort> newChars = new(newChar);

do
Vector<ushort> original;
Vector<ushort> equals;
Vector<ushort> results;

if (remainingLength > (nuint)Vector<ushort>.Count)
{
Vector<ushort> original = Unsafe.ReadUnaligned<Vector<ushort>>(ref Unsafe.As<ushort, byte>(ref pSrc));
Vector<ushort> equals = Vector.Equals(original, oldChars);
Vector<ushort> results = Vector.ConditionalSelect(equals, newChars, original);
Unsafe.WriteUnaligned(ref Unsafe.As<ushort, byte>(ref pDst), results);

pSrc = ref Unsafe.Add(ref pSrc, Vector<ushort>.Count);
pDst = ref Unsafe.Add(ref pDst, Vector<ushort>.Count);
remainingLength -= Vector<ushort>.Count;
nuint lengthToExamine = remainingLength - (nuint)Vector<ushort>.Count;

do
{
original = Vector.LoadUnsafe(ref pSrc, i);
equals = Vector.Equals(original, oldChars);
results = Vector.ConditionalSelect(equals, newChars, original);
results.StoreUnsafe(ref pDst, i);

i += (nuint)Vector<ushort>.Count;
}
while (i < lengthToExamine);
}
while (remainingLength >= Vector<ushort>.Count);
}

for (; remainingLength > 0; remainingLength--)
{
ushort currentChar = pSrc;
pDst = currentChar == oldChar ? newChar : currentChar;
// There are [0, Vector<ushort>.Count) elements remaining now.
// As the operation is idempotent, and we know that in total there are at least Vector<ushort>.Count
// elements available, we read a vector from the very end of the string, perform the replace
// and write to the destination at the very end.
// Thus we can eliminate the scalar processing of the remaining elements.
// We perform this operation even if there are 0 elements remaining, as it is cheaper than the
// additional check which would introduce a branch here.

pSrc = ref Unsafe.Add(ref pSrc, 1);
pDst = ref Unsafe.Add(ref pDst, 1);
i = (uint)(Length - Vector<ushort>.Count);
original = Vector.LoadUnsafe(ref GetRawStringDataAsUInt16(), i);
equals = Vector.Equals(original, oldChars);
results = Vector.ConditionalSelect(equals, newChars, original);
results.StoreUnsafe(ref result.GetRawStringDataAsUInt16(), i);
}
else
{
for (; i < remainingLength; ++i)
{
ushort currentChar = Unsafe.Add(ref pSrc, i);
Unsafe.Add(ref pDst, i) = currentChar == oldChar ? newChar : currentChar;
}
}

return result;
Expand Down
1 change: 1 addition & 0 deletions src/libraries/System.Private.CoreLib/src/System/String.cs
Original file line number Diff line number Diff line change
Expand Up @@ -508,6 +508,7 @@ public static bool IsNullOrWhiteSpace([NotNullWhen(false)] string? value)
public ref readonly char GetPinnableReference() => ref _firstChar;

internal ref char GetRawStringData() => ref _firstChar;
internal ref ushort GetRawStringDataAsUInt16() => ref Unsafe.As<char, ushort>(ref _firstChar);

// Helper for encodings so they can talk to our buffer directly
// stringLength must be the exact size we'll expect
Expand Down