diff --git a/src/libraries/System.Memory/ref/System.Memory.cs b/src/libraries/System.Memory/ref/System.Memory.cs index 8aac4764a1498a..8f6a5ac768e808 100644 --- a/src/libraries/System.Memory/ref/System.Memory.cs +++ b/src/libraries/System.Memory/ref/System.Memory.cs @@ -639,6 +639,24 @@ public static void WriteUIntPtrLittleEndian(System.Span destination, nuint } namespace System.Buffers.Text { + public static class Base64Url + { + public static OperationStatus EncodeToChars(System.ReadOnlySpan source, System.Span destination, out int bytesConsumed, out int charsWritten, bool isFinalBlock = true) { throw null; } + public static int EncodeToChars(System.ReadOnlySpan source, System.Span destination) { throw null; } + public static bool TryEncodeToChars(System.ReadOnlySpan source, System.Span destination, out int charsWritten) { throw null; } + public static char[] EncodeToChars(System.ReadOnlySpan source) { throw null; } + public static string EncodeToString(System.ReadOnlySpan source) { throw null; } + public static byte[] EncodeToUtf8(System.ReadOnlySpan source) { throw null; } + public static int EncodeToUtf8(System.ReadOnlySpan source, System.Span destination) { throw null; } + public static System.Buffers.OperationStatus EncodeToUtf8(System.ReadOnlySpan source, System.Span destination, out int bytesConsumed, out int bytesWritten, bool isFinalBlock = true) { throw null; } + public static int GetEncodedLength(int bytesLength) { throw null; } + public static bool IsValid(System.ReadOnlySpan base64UrlText) { throw null; } + public static bool IsValid(System.ReadOnlySpan base64UrlText, out int decodedLength) { throw null; } + public static bool IsValid(System.ReadOnlySpan utf8Base64UrlText) { throw null; } + public static bool IsValid(System.ReadOnlySpan utf8Base64UrlText, out int decodedLength) { throw null; } + public static bool TryEncodeToUtf8(System.ReadOnlySpan source, System.Span destination, out int charsWritten) { throw null; } + public static bool 
TryEncodeToUtf8InPlace(System.Span buffer, int dataLength, out int bytesWritten) { throw null; } + } public static partial class Utf8Formatter { public static bool TryFormat(bool value, System.Span destination, out int bytesWritten, System.Buffers.StandardFormat format = default(System.Buffers.StandardFormat)) { throw null; } diff --git a/src/libraries/System.Memory/tests/Base64/Base64ValidationUnitTests.cs b/src/libraries/System.Memory/tests/Base64/Base64ValidationUnitTests.cs index c7f164ad9b7f5f..ba08c79e53f226 100644 --- a/src/libraries/System.Memory/tests/Base64/Base64ValidationUnitTests.cs +++ b/src/libraries/System.Memory/tests/Base64/Base64ValidationUnitTests.cs @@ -267,7 +267,7 @@ public void InvalidSizeBytes(string utf8WithByteToBeIgnored) [InlineData("Y")] public void InvalidSizeChars(string utf8WithByteToBeIgnored) { - byte[] utf8BytesWithByteToBeIgnored = UTF8Encoding.UTF8.GetBytes(utf8WithByteToBeIgnored); + ReadOnlySpan utf8BytesWithByteToBeIgnored = utf8WithByteToBeIgnored; Assert.False(Base64.IsValid(utf8BytesWithByteToBeIgnored)); Assert.False(Base64.IsValid(utf8BytesWithByteToBeIgnored, out int decodedLength)); diff --git a/src/libraries/System.Memory/tests/Base64Url/Base64UrlEncoderTests.cs b/src/libraries/System.Memory/tests/Base64Url/Base64UrlEncoderTests.cs new file mode 100644 index 00000000000000..d3403d3ddac51b --- /dev/null +++ b/src/libraries/System.Memory/tests/Base64Url/Base64UrlEncoderTests.cs @@ -0,0 +1,302 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. 
+ +using System.SpanTests; +using System.Text; +using Xunit; + +namespace System.Buffers.Text.Tests +{ + public class Base64UrlEncoderTests + { + [Fact] + public void BasicEncodingAndDecoding() + { + var bytes = new byte[byte.MaxValue + 1]; + for (int i = 0; i < byte.MaxValue + 1; i++) + { + bytes[i] = (byte)i; + } + + for (int value = 0; value < 256; value++) + { + Span sourceBytes = bytes.AsSpan(0, value + 1); + Span encodedBytes = new byte[Base64Url.GetEncodedLength(sourceBytes.Length)]; + Assert.Equal(OperationStatus.Done, Base64Url.EncodeToUtf8(sourceBytes, encodedBytes, out int consumed, out int encodedBytesCount)); + Assert.Equal(sourceBytes.Length, consumed); + Assert.Equal(encodedBytes.Length, encodedBytesCount); + + string encodedText = Encoding.ASCII.GetString(encodedBytes.ToArray()); + string expectedText = Convert.ToBase64String(bytes, 0, value + 1).Replace('+', '-').Replace('/', '_').TrimEnd('='); + Assert.Equal(expectedText, encodedText); + + /*Assert.Equal(0, encodedBytes.Length % 4); + Span decodedBytes = new byte[Base64.GetMaxDecodedFromUtf8Length(encodedBytes.Length)]; + Assert.Equal(OperationStatus.Done, Base64.DecodeFromUtf8(encodedBytes, decodedBytes, out consumed, out int decodedByteCount)); + Assert.Equal(encodedBytes.Length, consumed); + Assert.Equal(sourceBytes.Length, decodedByteCount); + Assert.True(sourceBytes.SequenceEqual(decodedBytes.Slice(0, decodedByteCount)));*/ + } + } + + [Fact] + public void BasicEncoding() + { + var rnd = new Random(42); + for (int i = 0; i < 10; i++) + { + int numBytes = rnd.Next(100, 1000 * 1000); + Span source = new byte[numBytes]; + Base64TestHelper.InitializeBytes(source, numBytes); + + Span encodedBytes = new byte[Base64Url.GetEncodedLength(source.Length)]; + OperationStatus result = Base64Url.EncodeToUtf8(source, encodedBytes, out int consumed, out int encodedBytesCount); + Assert.Equal(OperationStatus.Done, result); + Assert.Equal(source.Length, consumed); + Assert.Equal(encodedBytes.Length, 
encodedBytesCount); + Assert.True(VerifyEncodingCorrectness(source.Length, encodedBytes.Length, source, encodedBytes)); + } + } + + private static bool VerifyEncodingCorrectness(int expectedConsumed, int expectedWritten, Span source, Span encodedBytes) + { + string expectedText = Convert.ToBase64String(source.Slice(0, expectedConsumed).ToArray()).Replace('+', '-').Replace('/', '_').TrimEnd('='); + string encodedText = Encoding.ASCII.GetString(encodedBytes.Slice(0, expectedWritten).ToArray()); + return expectedText.Equals(encodedText); + } + + [Fact] + public void BasicEncodingWithFinalBlockFalse() + { + var rnd = new Random(42); + for (int i = 0; i < 10; i++) + { + int numBytes = rnd.Next(100, 1000 * 1000); + Span source = new byte[numBytes]; + Base64TestHelper.InitializeBytes(source, numBytes); + + Span encodedBytes = new byte[Base64Url.GetEncodedLength(source.Length)]; + int expectedConsumed = source.Length / 3 * 3; // only consume closest multiple of three since isFinalBlock is false + int expectedWritten = source.Length / 3 * 4; + + // The constant random seed guarantees that both states are tested. + OperationStatus expectedStatus = numBytes % 3 == 0 ? 
OperationStatus.Done : OperationStatus.NeedMoreData; + Assert.Equal(expectedStatus, Base64Url.EncodeToUtf8(source, encodedBytes, out int consumed, out int encodedBytesCount, isFinalBlock: false)); + Assert.Equal(expectedConsumed, consumed); + Assert.Equal(expectedWritten, encodedBytesCount); + Assert.True(VerifyEncodingCorrectness(expectedConsumed, expectedWritten, source, encodedBytes)); + } + } + + [Theory] + [InlineData(1, "AQ")] + [InlineData(2, "AQI")] + [InlineData(3, "AQID")] + [InlineData(4, "AQIDBA")] + [InlineData(5, "AQIDBAU")] + [InlineData(6, "AQIDBAUG")] + [InlineData(7, "AQIDBAUGBw")] + public void BasicEncodingWithFinalBlockTrueKnownInput(int numBytes, string expectedText) + { + int expectedConsumed = numBytes; + int expectedWritten = expectedText.Length; + + Span source = new byte[numBytes]; + for (int i = 0; i < numBytes; i++) + { + source[i] = (byte)(i + 1); + } + Span encodedBytes = new byte[Base64Url.GetEncodedLength(source.Length)]; + + Assert.Equal(OperationStatus.Done, Base64Url.EncodeToUtf8(source, encodedBytes, out int consumed, out int encodedBytesCount)); + Assert.Equal(expectedConsumed, consumed); + Assert.Equal(expectedWritten, encodedBytesCount); + + string encodedText = Encoding.ASCII.GetString(encodedBytes.Slice(0, expectedWritten).ToArray()); + Assert.Equal(expectedText, encodedText); + } + + [Theory] + [InlineData(1, "", 0, 0)] + [InlineData(2, "", 0, 0)] + [InlineData(3, "AQID", 3, 4)] + [InlineData(4, "AQID", 3, 4)] + [InlineData(5, "AQID", 3, 4)] + [InlineData(6, "AQIDBAUG", 6, 8)] + [InlineData(7, "AQIDBAUG", 6, 8)] + public void BasicEncodingWithFinalBlockFalseKnownInput(int numBytes, string expectedText, int expectedConsumed, int expectedWritten) + { + Span source = new byte[numBytes]; + for (int i = 0; i < numBytes; i++) + { + source[i] = (byte)(i + 1); + } + Span encodedBytes = new byte[Base64Url.GetEncodedLength(source.Length)]; + + OperationStatus expectedStatus = numBytes % 3 == 0 ? 
OperationStatus.Done : OperationStatus.NeedMoreData; + Assert.Equal(expectedStatus, Base64Url.EncodeToUtf8(source, encodedBytes, out int consumed, out int encodedBytesCount, isFinalBlock: false)); + Assert.Equal(expectedConsumed, consumed); + Assert.Equal(expectedWritten, encodedBytesCount); + + string encodedText = Encoding.ASCII.GetString(encodedBytes.Slice(0, expectedWritten).ToArray()); + Assert.Equal(expectedText, encodedText); + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public void EncodeEmptySpan(bool isFinalBlock) + { + Span source = Span.Empty; + Span encodedBytes = new byte[Base64Url.GetEncodedLength(source.Length)]; + + Assert.Equal(OperationStatus.Done, Base64Url.EncodeToUtf8(source, encodedBytes, out int consumed, out int encodedBytesCount, isFinalBlock)); + Assert.Equal(0, consumed); + Assert.Equal(0, encodedBytesCount); + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public void EncodingOutputTooSmall(bool isFinalBlock) + { + for (int numBytes = 4; numBytes < 20; numBytes++) + { + Span source = new byte[numBytes]; + Base64TestHelper.InitializeBytes(source, numBytes); + + Span encodedBytes = new byte[4]; + Assert.Equal(OperationStatus.DestinationTooSmall, Base64Url.EncodeToUtf8(source, encodedBytes, out int consumed, out int written, isFinalBlock)); + int expectedConsumed = 3; + Assert.Equal(expectedConsumed, consumed); + Assert.Equal(encodedBytes.Length, written); + Assert.True(VerifyEncodingCorrectness(expectedConsumed, encodedBytes.Length, source, encodedBytes)); + } + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public void EncodingOutputTooSmallRetry(bool isFinalBlock) + { + Span source = new byte[750]; + Base64TestHelper.InitializeBytes(source); + + int outputSize = 320; + int requiredSize = Base64Url.GetEncodedLength(source.Length); + + Span encodedBytes = new byte[outputSize]; + Assert.Equal(OperationStatus.DestinationTooSmall, Base64Url.EncodeToUtf8(source, encodedBytes, out int consumed, out int 
written, isFinalBlock)); + int expectedConsumed = encodedBytes.Length / 4 * 3; + Assert.Equal(expectedConsumed, consumed); + Assert.Equal(encodedBytes.Length, written); + Assert.True(VerifyEncodingCorrectness(expectedConsumed, encodedBytes.Length, source, encodedBytes)); + + encodedBytes = new byte[requiredSize - outputSize]; + source = source.Slice(consumed); + Assert.Equal(OperationStatus.Done, Base64Url.EncodeToUtf8(source, encodedBytes, out consumed, out written, isFinalBlock)); + expectedConsumed = encodedBytes.Length / 4 * 3; + Assert.Equal(expectedConsumed, consumed); + Assert.Equal(encodedBytes.Length, written); + Assert.True(VerifyEncodingCorrectness(expectedConsumed, encodedBytes.Length, source, encodedBytes)); + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + [OuterLoop] + public void EncodeTooLargeSpan(bool isFinalBlock) + { + if (!Environment.Is64BitProcess) + return; + + bool allocatedFirst = false; + bool allocatedSecond = false; + IntPtr memBlockFirst = IntPtr.Zero; + IntPtr memBlockSecond = IntPtr.Zero; + + // int.MaxValue - (int.MaxValue % 4) => 2147483644, largest multiple of 4 less than int.MaxValue + // CLR default limit of 2 gigabytes (GB). 
+ // 1610612734, larger than MaximumEncodeLength, requires output buffer of size 2147483648 (which is > int.MaxValue) + const int sourceCount = (int.MaxValue >> 2) * 3 + 1; + const int encodedCount = 2000000000; + + try + { + allocatedFirst = AllocationHelper.TryAllocNative((IntPtr)sourceCount, out memBlockFirst); + allocatedSecond = AllocationHelper.TryAllocNative((IntPtr)encodedCount, out memBlockSecond); + if (allocatedFirst && allocatedSecond) + { + unsafe + { + var source = new Span(memBlockFirst.ToPointer(), sourceCount); + var encodedBytes = new Span(memBlockSecond.ToPointer(), encodedCount); + + Assert.Equal(OperationStatus.DestinationTooSmall, Base64Url.EncodeToUtf8(source, encodedBytes, out int consumed, out int encodedBytesCount, isFinalBlock)); + Assert.Equal((encodedBytes.Length >> 2) * 3, consumed); // encoding 1500000000 bytes fits into buffer of 2000000000 bytes + Assert.Equal(encodedBytes.Length, encodedBytesCount); + } + } + } + finally + { + if (allocatedFirst) + AllocationHelper.ReleaseNative(ref memBlockFirst); + if (allocatedSecond) + AllocationHelper.ReleaseNative(ref memBlockSecond); + } + } + + [Fact] + public void GetEncodedLength() + { + // (int.MaxValue - 4)/(4/3) => 1610612733, otherwise integer overflow + int[] input = { 0, 1, 2, 3, 4, 5, 6, 1610612728, 1610612729, 1610612730, 1610612731, 1610612732, 1610612733 }; + int[] expected = { 0, 2, 3, 4, 6, 7, 8, 2147483638, 2147483639, 2147483640, 2147483642, 2147483643, 2147483644 }; + for (int i = 0; i < input.Length; i++) + { + Assert.Equal(expected[i], Base64Url.GetEncodedLength(input[i])); + } + + // integer overflow + Assert.Throws(() => Base64Url.GetEncodedLength(1610612734)); + Assert.Throws(() => Base64Url.GetEncodedLength(int.MaxValue)); + + // negative input + Assert.Throws(() => Base64Url.GetEncodedLength(-1)); + Assert.Throws(() => Base64Url.GetEncodedLength(int.MinValue)); + } + + [Fact] + public void TryEncodeInPlace() + { + const int numberOfBytes = 15; + Span testBytes = new 
byte[numberOfBytes / 3 * 4]; // slack since encoding inflates the data + Base64TestHelper.InitializeBytes(testBytes); + + for (int numberOfBytesToTest = 0; numberOfBytesToTest <= numberOfBytes; numberOfBytesToTest++) + { + Span sliced = testBytes.Slice(0, numberOfBytesToTest); + Span dest = new byte[Base64Url.GetEncodedLength(numberOfBytesToTest)]; + //var expectedText = Convert.ToBase64String(testBytes.Slice(0, numberOfBytesToTest).ToArray()).Replace('+', '-').Replace('/', '_').TrimEnd('='); + var status = Base64Url.EncodeToUtf8(sliced, dest, out _, out int _); + Assert.Equal(OperationStatus.Done, status); + var expectedText = Encoding.ASCII.GetString(dest.ToArray()); + Assert.True(Base64Url.TryEncodeToUtf8InPlace(testBytes, numberOfBytesToTest, out int bytesWritten)); + Assert.Equal(Base64Url.GetEncodedLength(numberOfBytesToTest), bytesWritten); + + var encodedText = Encoding.ASCII.GetString(testBytes.Slice(0, bytesWritten).ToArray()); + Assert.Equal(expectedText, encodedText); + } + } + + [Fact] + public void TryEncodeInPlaceOutputTooSmall() + { + byte[] testBytes = { 1, 2, 3 }; + + Assert.False(Base64Url.TryEncodeToUtf8InPlace(testBytes, testBytes.Length, out int bytesWritten)); + Assert.Equal(0, bytesWritten); + } + } +} diff --git a/src/libraries/System.Memory/tests/Base64Url/Base64UrlValidationUnitTests.cs b/src/libraries/System.Memory/tests/Base64Url/Base64UrlValidationUnitTests.cs new file mode 100644 index 00000000000000..2ed3f41f9e39d6 --- /dev/null +++ b/src/libraries/System.Memory/tests/Base64Url/Base64UrlValidationUnitTests.cs @@ -0,0 +1,368 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. 
+ +using System.Linq; +using System.Text; +using Xunit; + +namespace System.Buffers.Text.Tests +{ + public class Base64UrlValidationUnitTests : Base64TestBase + { + public static readonly byte[] s_encodingMap = { + 65, 66, 67, 68, 69, 70, 71, 72, //A..H + 73, 74, 75, 76, 77, 78, 79, 80, //I..P + 81, 82, 83, 84, 85, 86, 87, 88, //Q..X + 89, 90, 97, 98, 99, 100, 101, 102, //Y..Z, a..f + 103, 104, 105, 106, 107, 108, 109, 110, //g..n + 111, 112, 113, 114, 115, 116, 117, 118, //o..v + 119, 120, 121, 122, 48, 49, 50, 51, //w..z, 0..3 + 52, 53, 54, 55, 56, 57, 45, 95 //4..9, -, _ + }; + + private static void InitializeDecodableBytes(Span bytes, int seed = 100) + { + var rnd = new Random(seed); + for (int i = 0; i < bytes.Length; i++) + { + int index = (byte)rnd.Next(0, s_encodingMap.Length); + bytes[i] = s_encodingMap[index]; + } + } + + [Fact] + public void BasicValidationBytes() + { + var rnd = new Random(42); + for (int i = 0; i < 10; i++) + { + int numBytes; + do + { + numBytes = rnd.Next(100, 1000 * 1000); + } while (numBytes % 4 != 0); // ensure we have a valid length + + Span source = new byte[numBytes]; + InitializeDecodableBytes(source, numBytes); + + Assert.True(Base64Url.IsValid(source)); + Assert.True(Base64Url.IsValid(source, out int decodedLength)); + Assert.True(decodedLength > 0); + } + } + + [Fact] + public void BasicValidationChars() + { + var rnd = new Random(42); + for (int i = 0; i < 10; i++) + { + int numBytes; + do + { + numBytes = rnd.Next(100, 1000 * 1000); + } while (numBytes % 4 != 0); // ensure we have a valid length + + Span source = new byte[numBytes]; + InitializeDecodableBytes(source, numBytes); + Span chars = source + .ToArray() + .Select(Convert.ToChar) + .ToArray() + .AsSpan(); + + Assert.True(Base64Url.IsValid(chars)); + Assert.True(Base64Url.IsValid(chars, out int decodedLength)); + Assert.True(decodedLength > 0); + } + } + + [Fact] + public void BasicValidationInvalidInputLengthBytes() + { + var rnd = new Random(42); + for (int i = 
0; i < 10; i++) + { + int numBytes; + do + { + numBytes = rnd.Next(100, 1000 * 1000); + } while (numBytes % 4 != 1); // unpadded Base64Url: only a length with remainder 1 (mod 4) is invalid + + Span source = new byte[numBytes]; + InitializeDecodableBytes(source, numBytes); // valid alphabet only, so the length is the sole reason for rejection + Assert.False(Base64Url.IsValid(source)); + Assert.False(Base64Url.IsValid(source, out int decodedLength)); + Assert.Equal(0, decodedLength); + } + } + + [Fact] + public void BasicValidationInvalidInputLengthChars() + { + var rnd = new Random(42); + for (int i = 0; i < 10; i++) + { + int numBytes; + do + { + numBytes = rnd.Next(100, 1000 * 1000); + } while (numBytes % 4 != 1); // unpadded Base64Url: only a length with remainder 1 (mod 4) is invalid + + Span source = new char[numBytes]; + source.Fill('A'); // valid alphabet only, so the length is the sole reason for rejection + Assert.False(Base64Url.IsValid(source)); + Assert.False(Base64Url.IsValid(source, out int decodedLength)); + Assert.Equal(0, decodedLength); + } + } + + [Fact] + public void ValidateEmptySpanBytes() + { + Span source = Span.Empty; + + Assert.True(Base64Url.IsValid(source)); + Assert.True(Base64Url.IsValid(source, out int decodedLength)); + Assert.Equal(0, decodedLength); + } + + [Fact] + public void ValidateEmptySpanChars() + { + Span source = Span.Empty; + + Assert.True(Base64Url.IsValid(source)); + Assert.True(Base64Url.IsValid(source, out int decodedLength)); + Assert.Equal(0, decodedLength); + } + + [Fact] + public void ValidateGuidBytes() + { + Span source = new byte[22]; + Span decodedBytes = Guid.NewGuid().ToByteArray(); + Base64Url.EncodeToUtf8(decodedBytes, source, out int _, out int _); + + Assert.True(Base64Url.IsValid(source)); + Assert.True(Base64Url.IsValid(source, out int decodedLength)); + Assert.True(decodedLength > 0); + } + + [Fact] + public void ValidateGuidChars() + { + Span source = new byte[22]; + Span decodedBytes = Guid.NewGuid().ToByteArray(); + Base64Url.EncodeToUtf8(decodedBytes, source, out int _, out int _); + Span chars = source + .ToArray() + .Select(Convert.ToChar) + .ToArray() + .AsSpan(); + + Assert.True(Base64Url.IsValid(chars)); + Assert.True(Base64Url.IsValid(chars, out int decodedLength)); 
+ Assert.True(decodedLength > 0); + } + + [Theory] + [MemberData(nameof(ValidBase64Strings_WithCharsThatMustBeIgnored))] + public void ValidateBytesIgnoresCharsToBeIgnoredBytes(string utf8WithByteToBeIgnored, byte[] expectedBytes) + { + byte[] utf8BytesWithByteToBeIgnored = UTF8Encoding.UTF8.GetBytes(utf8WithByteToBeIgnored); + + Assert.True(Base64Url.IsValid(utf8BytesWithByteToBeIgnored)); + Assert.True(Base64Url.IsValid(utf8BytesWithByteToBeIgnored, out int decodedLength)); + Assert.Equal(expectedBytes.Length, decodedLength); + } + + [Theory] + [MemberData(nameof(ValidBase64Strings_WithCharsThatMustBeIgnored))] + public void ValidateBytesIgnoresCharsToBeIgnoredChars(string utf8WithByteToBeIgnored, byte[] expectedBytes) + { + ReadOnlySpan utf8BytesWithByteToBeIgnored = utf8WithByteToBeIgnored.ToArray(); + + Assert.True(Base64Url.IsValid(utf8BytesWithByteToBeIgnored)); + Assert.True(Base64Url.IsValid(utf8BytesWithByteToBeIgnored, out int decodedLength)); + Assert.Equal(expectedBytes.Length, decodedLength); + } + + [Theory] + [MemberData(nameof(StringsOnlyWithCharsToBeIgnored))] + public void ValidateWithOnlyCharsToBeIgnoredBytes(string utf8WithByteToBeIgnored) + { + byte[] utf8BytesWithByteToBeIgnored = UTF8Encoding.UTF8.GetBytes(utf8WithByteToBeIgnored); + + Assert.True(Base64Url.IsValid(utf8BytesWithByteToBeIgnored)); + Assert.True(Base64Url.IsValid(utf8BytesWithByteToBeIgnored, out int decodedLength)); + Assert.Equal(0, decodedLength); + } + + [Theory] + [MemberData(nameof(StringsOnlyWithCharsToBeIgnored))] + public void ValidateWithOnlyCharsToBeIgnoredChars(string utf8WithByteToBeIgnored) + { + ReadOnlySpan utf8BytesWithByteToBeIgnored = utf8WithByteToBeIgnored.ToArray(); + + Assert.True(Base64Url.IsValid(utf8BytesWithByteToBeIgnored)); + Assert.True(Base64Url.IsValid(utf8BytesWithByteToBeIgnored, out int decodedLength)); + Assert.Equal(0, decodedLength); + } + + [Theory] + [InlineData("YQ==", 1)] + [InlineData("YWI=", 2)] + [InlineData("YWJj", 3)] + 
[InlineData(" YWI=", 2)] + [InlineData("Y WI=", 2)] + [InlineData("YW I=", 2)] + [InlineData("YWI =", 2)] + [InlineData("YWI= ", 2)] + [InlineData(" YQ==", 1)] + [InlineData("Y Q==", 1)] + [InlineData("YQ ==", 1)] + [InlineData("YQ= =", 1)] + [InlineData("YQ== ", 1)] + [InlineData("YQ%%", 1)] + [InlineData("YWI%", 2)] + [InlineData("YW% ", 1)] + public void ValidateWithPaddingReturnsCorrectCountBytes(string utf8WithByteToBeIgnored, int expectedLength) + { + byte[] utf8BytesWithByteToBeIgnored = UTF8Encoding.UTF8.GetBytes(utf8WithByteToBeIgnored); + + Assert.True(Base64Url.IsValid(utf8BytesWithByteToBeIgnored)); + Assert.True(Base64Url.IsValid(utf8BytesWithByteToBeIgnored, out int decodedLength)); + Assert.Equal(expectedLength, decodedLength); + } + + [Theory] + [InlineData("YQ==", 1)] + [InlineData("YWI=", 2)] + [InlineData("YWJj", 3)] + [InlineData(" YWI=", 2)] + [InlineData("Y WI=", 2)] + [InlineData("YW I=", 2)] + [InlineData("YWI =", 2)] + [InlineData("YWI= ", 2)] + [InlineData(" YQ==", 1)] + [InlineData("Y Q==", 1)] + [InlineData("YQ ==", 1)] + [InlineData("YQ= =", 1)] + [InlineData("YQ== ", 1)] + [InlineData("YQ%%", 1)] + [InlineData("YWI%", 2)] + [InlineData("YW% ", 1)] + public void ValidateWithPaddingReturnsCorrectCountChars(string utf8WithByteToBeIgnored, int expectedLength) + { + ReadOnlySpan utf8BytesWithByteToBeIgnored = utf8WithByteToBeIgnored.ToArray(); + + Assert.True(Base64Url.IsValid(utf8BytesWithByteToBeIgnored)); + Assert.True(Base64Url.IsValid(utf8BytesWithByteToBeIgnored, out int decodedLength)); + Assert.Equal(expectedLength, decodedLength); + } + + [Theory] + [InlineData("YQ==", 1)] + [InlineData("YWI=", 2)] + [InlineData("YWJj", 3)] + public void DecodeEmptySpan(string utf8WithByteToBeIgnored, int expectedLength) + { + ReadOnlySpan utf8BytesWithByteToBeIgnored = utf8WithByteToBeIgnored.ToArray(); + + Assert.True(Base64Url.IsValid(utf8BytesWithByteToBeIgnored)); + Assert.True(Base64Url.IsValid(utf8BytesWithByteToBeIgnored, out int 
decodedLength)); + Assert.Equal(expectedLength, decodedLength); + } + + [Theory] + [InlineData("YWJ", true, 2)] + [InlineData("YW", true, 1)] + [InlineData("Y", false, 0)] + public void SmallSizeBytes(string utf8Text, bool isValid, int expectedDecodedLength) + { + byte[] utf8BytesWithByteToBeIgnored = UTF8Encoding.UTF8.GetBytes(utf8Text); + + Assert.Equal(isValid, Base64Url.IsValid(utf8BytesWithByteToBeIgnored)); + Assert.Equal(isValid, Base64Url.IsValid(utf8BytesWithByteToBeIgnored, out int decodedLength)); + Assert.Equal(expectedDecodedLength, decodedLength); + } + + [Theory] + [InlineData("YWJ", true, 2)] + [InlineData("YW", true, 1)] + [InlineData("Y", false, 0)] + public void SmallSizeChars(string utf8Text, bool isValid, int expectedDecodedLength) + { + ReadOnlySpan utf8BytesWithByteToBeIgnored = utf8Text; + + Assert.Equal(isValid, Base64Url.IsValid(utf8BytesWithByteToBeIgnored)); + Assert.Equal(isValid, Base64Url.IsValid(utf8BytesWithByteToBeIgnored, out int decodedLength)); + Assert.Equal(expectedDecodedLength, decodedLength); + } + + [Theory] + [InlineData("YQ===")] + [InlineData("YQ=a=")] + [InlineData("YWI=a")] + [InlineData(" aYWI=a")] + [InlineData("a YWI=a")] + [InlineData("aY WI=a")] + [InlineData("aYW I=a")] + [InlineData("aYWI =a")] + [InlineData("aYWI= a")] + [InlineData("a YQ==a")] + [InlineData("aY Q==a")] + [InlineData("aYQ ==a")] + [InlineData("aYQ= =a")] + [InlineData("aYQ== a")] + [InlineData("aYQ==a ")] + [InlineData("YQ+a")] // plus invalid + [InlineData("/Qab")] // slash invalid + public void InvalidBase64UrlBytes(string utf8WithByteToBeIgnored) + { + byte[] utf8BytesWithByteToBeIgnored = UTF8Encoding.UTF8.GetBytes(utf8WithByteToBeIgnored); + + Assert.False(Base64Url.IsValid(utf8BytesWithByteToBeIgnored)); + Assert.False(Base64Url.IsValid(utf8BytesWithByteToBeIgnored, out int decodedLength)); + Assert.Equal(0, decodedLength); + } + + [Theory] + [InlineData("YQ===")] + [InlineData("YQ=a=")] + [InlineData("YWI=a")] + [InlineData("a YWI=a")] 
+ [InlineData("aY WI=a")] + [InlineData("aYW I=a")] + [InlineData("aYWI =a")] + [InlineData("aYWI= a")] + [InlineData("a YQ==a")] + [InlineData("aY Q==a")] + [InlineData("aYQ ==a")] + [InlineData("aYQ= =a")] + [InlineData("aYQ== a")] + [InlineData("aYQ==a ")] + [InlineData("a")] + [InlineData(" a")] + [InlineData(" a")] + [InlineData(" a")] + [InlineData(" a")] + [InlineData("a ")] + [InlineData("a ")] + [InlineData("a ")] + [InlineData("a ")] + [InlineData(" a ")] + [InlineData(" a ")] + [InlineData(" a ")] + [InlineData(" a ")] + public void InvalidBase64UrlChars(string utf8WithByteToBeIgnored) + { + ReadOnlySpan<char> utf8BytesWithByteToBeIgnored = utf8WithByteToBeIgnored; // implicit string -> ReadOnlySpan<char>, so the char overloads of IsValid are the ones exercised + + Assert.False(Base64Url.IsValid(utf8BytesWithByteToBeIgnored)); + Assert.False(Base64Url.IsValid(utf8BytesWithByteToBeIgnored, out int decodedLength)); + Assert.Equal(0, decodedLength); + } + } +} diff --git a/src/libraries/System.Memory/tests/System.Memory.Tests.csproj b/src/libraries/System.Memory/tests/System.Memory.Tests.csproj index 843d5e1b479ceb..1087e40071f094 100644 --- a/src/libraries/System.Memory/tests/System.Memory.Tests.csproj +++ b/src/libraries/System.Memory/tests/System.Memory.Tests.csproj @@ -13,6 +13,8 @@ + + diff --git a/src/libraries/System.Private.CoreLib/src/System.Private.CoreLib.Shared.projitems b/src/libraries/System.Private.CoreLib/src/System.Private.CoreLib.Shared.projitems index 22a4ddea6ed606..333c4281f87810 100644 --- a/src/libraries/System.Private.CoreLib/src/System.Private.CoreLib.Shared.projitems +++ b/src/libraries/System.Private.CoreLib/src/System.Private.CoreLib.Shared.projitems @@ -131,6 +131,9 @@ + + + diff --git a/src/libraries/System.Private.CoreLib/src/System/Buffers/Text/Base64.cs b/src/libraries/System.Private.CoreLib/src/System/Buffers/Text/Base64.cs index 4b1f597e5e8b5b..f19586eb120f75 100644 --- a/src/libraries/System.Private.CoreLib/src/System/Buffers/Text/Base64.cs +++ 
b/src/libraries/System.Private.CoreLib/src/System/Buffers/Text/Base64.cs @@ -3,13 +3,14 @@ using System.Diagnostics; using System.Runtime.CompilerServices; +using System.Runtime.Intrinsics; namespace System.Buffers.Text { public static partial class Base64 { [Conditional("DEBUG")] - private static unsafe void AssertRead(byte* src, byte* srcStart, int srcLength) + internal static unsafe void AssertRead(byte* src, byte* srcStart, int srcLength) { int vectorElements = Unsafe.SizeOf(); byte* readEnd = src + vectorElements; @@ -23,7 +24,7 @@ private static unsafe void AssertRead(byte* src, byte* srcStart, int sr } [Conditional("DEBUG")] - private static unsafe void AssertWrite(byte* dest, byte* destStart, int destLength) + internal static unsafe void AssertWrite(byte* dest, byte* destStart, int destLength) { int vectorElements = Unsafe.SizeOf(); byte* writeEnd = dest + vectorElements; @@ -35,5 +36,18 @@ private static unsafe void AssertWrite(byte* dest, byte* destStart, int Debug.Fail($"Write for {typeof(TVector)} is not within safe bounds. 
destIndex: {destIndex}, destLength: {destLength}"); } } + + internal interface IBase64Encoder + { + static abstract int IncrementPadTwo { get; } + static abstract int IncrementPadOne { get; } + static abstract ReadOnlySpan EncodingMap { get; } + static abstract Vector256 Avx2Lut { get; } + static abstract Vector128 AdvSimdLut4 { get; } + static abstract Vector128 Ssse3AdvSimdLut { get; } + static abstract int GetMaxSrcLength(int srcLength, int destLength); + static abstract unsafe uint EncodeOneOptionallyPadTwo(byte* oneByte, ref byte encodingMap); + static abstract unsafe uint EncodeTwoOptionallyPadOne(byte* oneByte, ref byte encodingMap); + } } } diff --git a/src/libraries/System.Private.CoreLib/src/System/Buffers/Text/Base64Decoder.cs b/src/libraries/System.Private.CoreLib/src/System/Buffers/Text/Base64Decoder.cs index 1ad2cf9faa9f30..c9bbcfb6b77127 100644 --- a/src/libraries/System.Private.CoreLib/src/System/Buffers/Text/Base64Decoder.cs +++ b/src/libraries/System.Private.CoreLib/src/System/Buffers/Text/Base64Decoder.cs @@ -832,7 +832,7 @@ private static unsafe void Avx2Decode(ref byte* srcBytes, ref byte* destBytes, b [MethodImpl(MethodImplOptions.AggressiveInlining)] [CompExactlyDependsOn(typeof(Ssse3))] [CompExactlyDependsOn(typeof(AdvSimd.Arm64))] - private static Vector128 SimdShuffle(Vector128 left, Vector128 right, Vector128 mask8F) + internal static Vector128 SimdShuffle(Vector128 left, Vector128 right, Vector128 mask8F) { Debug.Assert((Ssse3.IsSupported || AdvSimd.Arm64.IsSupported) && BitConverter.IsLittleEndian); diff --git a/src/libraries/System.Private.CoreLib/src/System/Buffers/Text/Base64Encoder.cs b/src/libraries/System.Private.CoreLib/src/System/Buffers/Text/Base64Encoder.cs index b63c711e410326..b8f9339782d17a 100644 --- a/src/libraries/System.Private.CoreLib/src/System/Buffers/Text/Base64Encoder.cs +++ b/src/libraries/System.Private.CoreLib/src/System/Buffers/Text/Base64Encoder.cs @@ -1,7 +1,6 @@ // Licensed to the .NET Foundation under one 
or more agreements. // The .NET Foundation licenses this file to you under the MIT license. -using System.Runtime; using System.Runtime.CompilerServices; using System.Runtime.InteropServices; using System.Runtime.Intrinsics; @@ -35,7 +34,11 @@ public static partial class Base64 /// - NeedMoreData - only if is , otherwise the output is padded if the input is not a multiple of 3 /// It does not return InvalidData since that is not possible for base64 encoding. /// - public static unsafe OperationStatus EncodeToUtf8(ReadOnlySpan bytes, Span utf8, out int bytesConsumed, out int bytesWritten, bool isFinalBlock = true) + public static unsafe OperationStatus EncodeToUtf8(ReadOnlySpan bytes, Span utf8, out int bytesConsumed, out int bytesWritten, bool isFinalBlock = true) => + EncodeToUtf8(bytes, utf8, out bytesConsumed, out bytesWritten, isFinalBlock); + + internal static unsafe OperationStatus EncodeToUtf8(ReadOnlySpan bytes, Span utf8, out int bytesConsumed, out int bytesWritten, bool isFinalBlock = true) + where TBase64Encoder : IBase64Encoder { if (bytes.IsEmpty) { @@ -49,16 +52,7 @@ public static unsafe OperationStatus EncodeToUtf8(ReadOnlySpan bytes, Span { int srcLength = bytes.Length; int destLength = utf8.Length; - int maxSrcLength; - - if (srcLength <= MaximumEncodeLength && destLength >= GetMaxEncodedToUtf8Length(srcLength)) - { - maxSrcLength = srcLength; - } - else - { - maxSrcLength = (destLength >> 2) * 3; - } + int maxSrcLength = TBase64Encoder.GetMaxSrcLength(srcLength, destLength); byte* src = srcBytes; byte* dest = destBytes; @@ -67,19 +61,19 @@ public static unsafe OperationStatus EncodeToUtf8(ReadOnlySpan bytes, Span if (maxSrcLength >= 16) { - byte* end = srcMax - 64; + byte* end = srcMax - 48; if (Vector512.IsHardwareAccelerated && Avx512Vbmi.IsSupported && (end >= src)) { - Avx512Encode(ref src, ref dest, end, maxSrcLength, destLength, srcBytes, destBytes); + Avx512Encode(ref src, ref dest, end, maxSrcLength, destLength, srcBytes, destBytes); if 
(src == srcEnd) goto DoneExit; } - end = srcMax - 64; + end = srcMax - 32; if (Avx2.IsSupported && (end >= src)) { - Avx2Encode(ref src, ref dest, end, maxSrcLength, destLength, srcBytes, destBytes); + Avx2Encode(ref src, ref dest, end, maxSrcLength, destLength, srcBytes, destBytes); if (src == srcEnd) goto DoneExit; @@ -88,7 +82,7 @@ public static unsafe OperationStatus EncodeToUtf8(ReadOnlySpan bytes, Span end = srcMax - 48; if (AdvSimd.Arm64.IsSupported && (end >= src)) { - AdvSimdEncode(ref src, ref dest, end, maxSrcLength, destLength, srcBytes, destBytes); + AdvSimdEncode(ref src, ref dest, end, maxSrcLength, destLength, srcBytes, destBytes); if (src == srcEnd) goto DoneExit; @@ -97,14 +91,14 @@ public static unsafe OperationStatus EncodeToUtf8(ReadOnlySpan bytes, Span end = srcMax - 16; if ((Ssse3.IsSupported || AdvSimd.Arm64.IsSupported) && BitConverter.IsLittleEndian && (end >= src)) { - Vector128Encode(ref src, ref dest, end, maxSrcLength, destLength, srcBytes, destBytes); + Vector128Encode(ref src, ref dest, end, maxSrcLength, destLength, srcBytes, destBytes); if (src == srcEnd) goto DoneExit; } } - ref byte encodingMap = ref MemoryMarshal.GetReference(EncodingMap); + ref byte encodingMap = ref MemoryMarshal.GetReference(TBase64Encoder.EncodingMap); uint result = 0; srcMax -= 2; @@ -129,17 +123,17 @@ public static unsafe OperationStatus EncodeToUtf8(ReadOnlySpan bytes, Span if (src + 1 == srcEnd) { - result = EncodeAndPadTwo(src, ref encodingMap); + result = TBase64Encoder.EncodeOneOptionallyPadTwo(src, ref encodingMap); Unsafe.WriteUnaligned(dest, result); src += 1; - dest += 4; + dest += TBase64Encoder.IncrementPadTwo; } else if (src + 2 == srcEnd) { - result = EncodeAndPadOne(src, ref encodingMap); + result = TBase64Encoder.EncodeTwoOptionallyPadOne(src, ref encodingMap); Unsafe.WriteUnaligned(dest, result); src += 2; - dest += 4; + dest += TBase64Encoder.IncrementPadOne; } DoneExit: @@ -201,25 +195,28 @@ public static unsafe OperationStatus 
EncodeToUtf8InPlace(Span buffer, int { int encodedLength = GetMaxEncodedToUtf8Length(dataLength); if (buffer.Length < encodedLength) - goto FalseExit; + { + bytesWritten = 0; + return OperationStatus.DestinationTooSmall; + } int leftover = dataLength - (dataLength / 3) * 3; // how many bytes after packs of 3 uint destinationIndex = (uint)(encodedLength - 4); uint sourceIndex = (uint)(dataLength - leftover); uint result = 0; - ref byte encodingMap = ref MemoryMarshal.GetReference(EncodingMap); + ref byte encodingMap = ref MemoryMarshal.GetReference(Base64Encoder.EncodingMap); // encode last pack to avoid conditional in the main loop if (leftover != 0) { if (leftover == 1) { - result = EncodeAndPadTwo(bufferBytes + sourceIndex, ref encodingMap); + result = Base64Encoder.EncodeOneOptionallyPadTwo(bufferBytes + sourceIndex, ref encodingMap); } else { - result = EncodeAndPadOne(bufferBytes + sourceIndex, ref encodingMap); + result = Base64Encoder.EncodeTwoOptionallyPadOne(bufferBytes + sourceIndex, ref encodingMap); } Unsafe.WriteUnaligned(bufferBytes + destinationIndex, result); @@ -237,17 +234,14 @@ public static unsafe OperationStatus EncodeToUtf8InPlace(Span buffer, int bytesWritten = encodedLength; return OperationStatus.Done; - - FalseExit: - bytesWritten = 0; - return OperationStatus.DestinationTooSmall; } } [MethodImpl(MethodImplOptions.AggressiveInlining)] [CompExactlyDependsOn(typeof(Avx512BW))] [CompExactlyDependsOn(typeof(Avx512Vbmi))] - private static unsafe void Avx512Encode(ref byte* srcBytes, ref byte* destBytes, byte* srcEnd, int sourceLength, int destLength, byte* srcStart, byte* destStart) + private static unsafe void Avx512Encode(ref byte* srcBytes, ref byte* destBytes, byte* srcEnd, int sourceLength, int destLength, byte* srcStart, byte* destStart) + where TBase64Encoder : IBase64Encoder { // Reference for VBMI implementation : https://github.com/WojciechMula/base64simd/tree/master/encode // If we have AVX512 support, pick off 48 bytes at a time for 
as long as we can. @@ -263,7 +257,7 @@ private static unsafe void Avx512Encode(ref byte* srcBytes, ref byte* destBytes, 0x0d0e0c0d, 0x10110f10, 0x13141213, 0x16171516, 0x191a1819, 0x1c1d1b1c, 0x1f201e1f, 0x22232122, 0x25262425, 0x28292728, 0x2b2c2a2b, 0x2e2f2d2e).AsSByte(); - Vector512 vbmiLookup = Vector512.Create("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"u8).AsSByte(); + Vector512 vbmiLookup = Vector512.Create(TBase64Encoder.EncodingMap).AsSByte(); Vector512 maskAC = Vector512.Create((uint)0x0fc0fc00).AsUInt16(); Vector512 maskBB = Vector512.Create((uint)0x3f003f00); @@ -273,7 +267,7 @@ private static unsafe void Avx512Encode(ref byte* srcBytes, ref byte* destBytes, AssertRead>(src, srcStart, sourceLength); // This algorithm requires AVX512VBMI support. - // Vbmi was first introduced in CannonLake and is avaialable from IceLake on. + // Vbmi was first introduced in CannonLake and is available from IceLake on. // str = [...|PONM|LKJI|HGFE|DCBA] Vector512 str = Vector512.Load(src).AsSByte(); @@ -320,7 +314,8 @@ private static unsafe void Avx512Encode(ref byte* srcBytes, ref byte* destBytes, [MethodImpl(MethodImplOptions.AggressiveInlining)] [CompExactlyDependsOn(typeof(Avx2))] - private static unsafe void Avx2Encode(ref byte* srcBytes, ref byte* destBytes, byte* srcEnd, int sourceLength, int destLength, byte* srcStart, byte* destStart) + private static unsafe void Avx2Encode(ref byte* srcBytes, ref byte* destBytes, byte* srcEnd, int sourceLength, int destLength, byte* srcStart, byte* destStart) + where TBase64Encoder : IBase64Encoder { // If we have AVX2 support, pick off 24 bytes at a time for as long as we can. 
// But because we read 32 bytes at a time, ensure we have enough room to do a @@ -345,15 +340,7 @@ private static unsafe void Avx2Encode(ref byte* srcBytes, ref byte* destBytes, b 7, 6, 8, 7, 10, 9, 11, 10); - Vector256 lut = Vector256.Create( - 65, 71, -4, -4, - -4, -4, -4, -4, - -4, -4, -4, -4, - -19, -16, 0, 0, - 65, 71, -4, -4, - -4, -4, -4, -4, - -4, -4, -4, -4, - -19, -16, 0, 0); + Vector256 lut = TBase64Encoder.Avx2Lut; Vector256 maskAC = Vector256.Create(0x0fc0fc00).AsSByte(); Vector256 maskBB = Vector256.Create(0x003f03f0).AsSByte(); @@ -491,7 +478,8 @@ private static unsafe void Avx2Encode(ref byte* srcBytes, ref byte* destBytes, b [MethodImpl(MethodImplOptions.AggressiveInlining)] [CompExactlyDependsOn(typeof(AdvSimd.Arm64))] - private static unsafe void AdvSimdEncode(ref byte* srcBytes, ref byte* destBytes, byte* srcEnd, int sourceLength, int destLength, byte* srcStart, byte* destStart) + private static unsafe void AdvSimdEncode(ref byte* srcBytes, ref byte* destBytes, byte* srcEnd, int sourceLength, int destLength, byte* srcStart, byte* destStart) + where TBase64Encoder : IBase64Encoder { // C# implementation of https://github.com/aklomp/base64/blob/3a5add8652076612a8407627a42c768736a4263f/lib/arch/neon64/enc_loop.c Vector128 str1; @@ -504,7 +492,7 @@ private static unsafe void AdvSimdEncode(ref byte* srcBytes, ref byte* destBytes Vector128 tblEnc1 = Vector128.Create("ABCDEFGHIJKLMNOP"u8).AsByte(); Vector128 tblEnc2 = Vector128.Create("QRSTUVWXYZabcdef"u8).AsByte(); Vector128 tblEnc3 = Vector128.Create("ghijklmnopqrstuv"u8).AsByte(); - Vector128 tblEnc4 = Vector128.Create("wxyz0123456789+/"u8).AsByte(); + Vector128 tblEnc4 = TBase64Encoder.AdvSimdLut4; byte* src = srcBytes; byte* dest = destBytes; @@ -550,7 +538,8 @@ private static unsafe void AdvSimdEncode(ref byte* srcBytes, ref byte* destBytes [MethodImpl(MethodImplOptions.AggressiveInlining)] [CompExactlyDependsOn(typeof(Ssse3))] [CompExactlyDependsOn(typeof(AdvSimd.Arm64))] - private static unsafe 
void Vector128Encode(ref byte* srcBytes, ref byte* destBytes, byte* srcEnd, int sourceLength, int destLength, byte* srcStart, byte* destStart) + private static unsafe void Vector128Encode(ref byte* srcBytes, ref byte* destBytes, byte* srcEnd, int sourceLength, int destLength, byte* srcStart, byte* destStart) + where TBase64Encoder : IBase64Encoder { // If we have SSSE3 support, pick off 12 bytes at a time for as long as we can. // But because we read 16 bytes at a time, ensure we have enough room to do a @@ -561,7 +550,7 @@ private static unsafe void Vector128Encode(ref byte* srcBytes, ref byte* destByt // The JIT won't hoist these "constants", so help it Vector128 shuffleVec = Vector128.Create(0x01020001, 0x04050304, 0x07080607, 0x0A0B090A).AsByte(); - Vector128 lut = Vector128.Create(0xFCFC4741, 0xFCFCFCFC, 0xFCFCFCFC, 0x0000F0ED).AsByte(); + Vector128 lut = TBase64Encoder.Ssse3AdvSimdLut; Vector128 maskAC = Vector128.Create(0x0fc0fc00).AsByte(); Vector128 maskBB = Vector128.Create(0x003f03f0).AsByte(); Vector128 shiftAC = Vector128.Create(0x04000040).AsUInt16(); @@ -672,7 +661,7 @@ private static unsafe void Vector128Encode(ref byte* srcBytes, ref byte* destByt } [MethodImpl(MethodImplOptions.AggressiveInlining)] - private static unsafe uint Encode(byte* threeBytes, ref byte encodingMap) + internal static unsafe uint Encode(byte* threeBytes, ref byte encodingMap) { uint t0 = threeBytes[0]; uint t1 = threeBytes[1]; @@ -695,52 +684,76 @@ private static unsafe uint Encode(byte* threeBytes, ref byte encodingMap) } } - [MethodImpl(MethodImplOptions.AggressiveInlining)] - private static unsafe uint EncodeAndPadOne(byte* twoBytes, ref byte encodingMap) + internal const uint EncodingPad = '='; // '=', for padding + + internal const int MaximumEncodeLength = (int.MaxValue / 4) * 3; // 1610612733 + + private readonly struct Base64Encoder : IBase64Encoder { - uint t0 = twoBytes[0]; - uint t1 = twoBytes[1]; + public static ReadOnlySpan EncodingMap => 
"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"u8; - uint i = (t0 << 16) | (t1 << 8); + public static Vector256 Avx2Lut => Vector256.Create( + 65, 71, -4, -4, + -4, -4, -4, -4, + -4, -4, -4, -4, + -19, -16, 0, 0, + 65, 71, -4, -4, + -4, -4, -4, -4, + -4, -4, -4, -4, + -19, -16, 0, 0); - uint i0 = Unsafe.Add(ref encodingMap, (IntPtr)(i >> 18)); - uint i1 = Unsafe.Add(ref encodingMap, (IntPtr)((i >> 12) & 0x3F)); - uint i2 = Unsafe.Add(ref encodingMap, (IntPtr)((i >> 6) & 0x3F)); + public static Vector128 AdvSimdLut4 => Vector128.Create("wxyz0123456789+/"u8).AsByte(); - if (BitConverter.IsLittleEndian) - { - return i0 | (i1 << 8) | (i2 << 16) | (EncodingPad << 24); - } - else - { - return (i0 << 24) | (i1 << 16) | (i2 << 8) | EncodingPad; - } - } + public static Vector128 Ssse3AdvSimdLut => Vector128.Create(0xFCFC4741, 0xFCFCFCFC, 0xFCFCFCFC, 0x0000F0ED).AsByte(); - [MethodImpl(MethodImplOptions.AggressiveInlining)] - private static unsafe uint EncodeAndPadTwo(byte* oneByte, ref byte encodingMap) - { - uint t0 = oneByte[0]; + public static int IncrementPadTwo => 4; - uint i = t0 << 8; + public static int IncrementPadOne => 4; - uint i0 = Unsafe.Add(ref encodingMap, (IntPtr)(i >> 10)); - uint i1 = Unsafe.Add(ref encodingMap, (IntPtr)((i >> 4) & 0x3F)); + public static int GetMaxSrcLength(int srcLength, int destLength) => + srcLength <= MaximumEncodeLength && destLength >= GetMaxEncodedToUtf8Length(srcLength) ? 
srcLength : (destLength >> 2) * 3; - if (BitConverter.IsLittleEndian) + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static unsafe uint EncodeOneOptionallyPadTwo(byte* oneByte, ref byte encodingMap) { - return i0 | (i1 << 8) | (EncodingPad << 16) | (EncodingPad << 24); + uint t0 = oneByte[0]; + + uint i = t0 << 8; + + uint i0 = Unsafe.Add(ref encodingMap, (IntPtr)(i >> 10)); + uint i1 = Unsafe.Add(ref encodingMap, (IntPtr)((i >> 4) & 0x3F)); + + if (BitConverter.IsLittleEndian) + { + return i0 | (i1 << 8) | (EncodingPad << 16) | (EncodingPad << 24); + } + else + { + return (i0 << 24) | (i1 << 16) | (EncodingPad << 8) | EncodingPad; + } } - else + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static unsafe uint EncodeTwoOptionallyPadOne(byte* twoBytes, ref byte encodingMap) { - return (i0 << 24) | (i1 << 16) | (EncodingPad << 8) | EncodingPad; - } - } + uint t0 = twoBytes[0]; + uint t1 = twoBytes[1]; - internal const uint EncodingPad = '='; // '=', for padding + uint i = (t0 << 16) | (t1 << 8); - private const int MaximumEncodeLength = (int.MaxValue / 4) * 3; // 1610612733 + uint i0 = Unsafe.Add(ref encodingMap, (IntPtr)(i >> 18)); + uint i1 = Unsafe.Add(ref encodingMap, (IntPtr)((i >> 12) & 0x3F)); + uint i2 = Unsafe.Add(ref encodingMap, (IntPtr)((i >> 6) & 0x3F)); - internal static ReadOnlySpan EncodingMap => "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"u8; + if (BitConverter.IsLittleEndian) + { + return i0 | (i1 << 8) | (i2 << 16) | (EncodingPad << 24); + } + else + { + return (i0 << 24) | (i1 << 16) | (i2 << 8) | EncodingPad; + } + } + } } } diff --git a/src/libraries/System.Private.CoreLib/src/System/Buffers/Text/Base64Url/Base64UrlDecoder.cs b/src/libraries/System.Private.CoreLib/src/System/Buffers/Text/Base64Url/Base64UrlDecoder.cs new file mode 100644 index 00000000000000..b84e58a8a63e64 --- /dev/null +++ b/src/libraries/System.Private.CoreLib/src/System/Buffers/Text/Base64Url/Base64UrlDecoder.cs @@ -0,0 
+1,42 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace System.Buffers.Text +{ + public static partial class Base64Url + { + /*// Decode from utf8 => bytes + public static OperationStatus DecodeFromUtf8(ReadOnlySpan utf8, Span bytes, out int bytesConsumed, out int bytesWritten, bool isFinalBlock = true) => throw new NotImplementedException(); + public static OperationStatus DecodeFromUtf8InPlace(Span buffer, out int bytesWritten) => throw new NotImplementedException(); + + /// + /// Returns the maximum length (in bytes) of the result if you were to decode base 64 encoded text within a byte span of size "length". + /// + /// + /// Thrown when the specified is less than 0. + /// + public static int GetMaxDecodedFromUtf8Length(int length) => throw new NotImplementedException(); + + // IsValid + public static bool IsValid(ReadOnlySpan base64UrlText) => throw new NotImplementedException(); + public static bool IsValid(ReadOnlySpan base64UrlText, out int decodedLength) => throw new NotImplementedException(); + public static bool IsValid(ReadOnlySpan base64UrlTextUtf8) => throw new NotImplementedException(); + public static bool IsValid(ReadOnlySpan base64UrlTextUtf8, out int decodedLength) => throw new NotImplementedException(); + + // Up to this point, this is a mirror of System.Buffers.Text.Base64 + // Below are more helpers that bring over functionality similar to Convert.*Base64* + + // Encode to / decode from chars + public static bool TryDecodeFromChars(ReadOnlySpan chars, Span bytes, out int bytesWritten) => throw new NotImplementedException(); + + + // These are just accelerator methods. + // Should be efficiently implementable on top of the other ones in just a few lines. 
+ + // Decode from chars => string + // Decode from chars => byte[] + // The names could also just be "Decode" without naming the return type + public static string DecodeToString(ReadOnlySpan chars, Encoding encoding) => throw new NotImplementedException(); + public static byte[] DecodeToByteArray(ReadOnlySpan chars) => throw new NotImplementedException();*/ + } +} diff --git a/src/libraries/System.Private.CoreLib/src/System/Buffers/Text/Base64Url/Base64UrlEncoder.cs b/src/libraries/System.Private.CoreLib/src/System/Buffers/Text/Base64Url/Base64UrlEncoder.cs new file mode 100644 index 00000000000000..20b2e1256bb995 --- /dev/null +++ b/src/libraries/System.Private.CoreLib/src/System/Buffers/Text/Base64Url/Base64UrlEncoder.cs @@ -0,0 +1,326 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; +using System.Runtime.Intrinsics; +using static System.Buffers.Text.Base64; + +namespace System.Buffers.Text +{ + public static partial class Base64Url + { + /// + /// Encode the span of binary data into UTF-8 encoded text represented as Base64Url. + /// + /// The input span which contains binary data that needs to be encoded. + /// The output span which contains the result of the operation, i.e. the UTF-8 encoded text in Base64Url. + /// The number of input bytes consumed during the operation. This can be used to slice the input for subsequent calls, if necessary. + /// The number of bytes written into the output span. This can be used to slice the output for subsequent calls, if necessary. + /// (default) when the input span contains the entire data to encode. + /// Set to when the source buffer contains the entirety of the data to encode. + /// Set to if this method is being called in a loop and if more input data may follow. 
+ /// At the end of the loop, call this (potentially with an empty source buffer) passing . + /// It returns the OperationStatus enum values: + /// - Done - on successful processing of the entire input span + /// - DestinationTooSmall - if there is not enough space in the output span to fit the encoded input + /// - NeedMoreData - only if is + /// It does not return InvalidData since that is not possible for base64 encoding. + /// + /// The output will not be padded even if the input is not a multiple of 3. + public static unsafe OperationStatus EncodeToUtf8(ReadOnlySpan source, Span destination, out int bytesConsumed, out int bytesWritten, bool isFinalBlock = true) => + EncodeToUtf8(source, destination, out bytesConsumed, out bytesWritten, isFinalBlock); + + /// + /// Returns the length (in bytes) of the result if you were to encode binary data within a byte span of size "length". + /// + /// + /// Thrown when the specified is less than 0 or larger than 1610612733 (since encode inflates the data by 4/3). + /// + public static int GetEncodedLength(int bytesLength) + { + if ((uint)bytesLength > Base64.MaximumEncodeLength) + ThrowHelper.ThrowArgumentOutOfRangeException(ExceptionArgument.length); + + int remainder = bytesLength % 3; + + return bytesLength / 3 * 4 + (remainder > 0 ? remainder + 1 : 0); // if remainder is 1 or 2, the encoded length will be 1 byte longer. + } + + /// + /// Encode the span of binary data into UTF-8 encoded text represented as Base64Url. + /// + /// The input span which contains binary data that needs to be encoded. + /// The output span which contains the result of the operation, i.e. the UTF-8 encoded text in Base64Url. + /// The number of bytes written into the destination span. This can be used to slice the output for subsequent calls, if necessary. + /// The output will not be padded even if the input is not a multiple of 3. 
+        public static int EncodeToUtf8(ReadOnlySpan<byte> source, Span<byte> destination)
+        {
+            // NOTE(review): the OperationStatus is not inspected, so a too-small destination yields a
+            // partial write whose length is returned without any error signal - confirm this is intended.
+            EncodeToUtf8(source, destination, out _, out int written);
+
+            return written;
+        }
+
+        /// <summary>
+        /// Encodes the span of binary data into a newly allocated array of UTF-8 encoded Base64Url text.
+        /// </summary>
+        /// <param name="source">The input span which contains binary data that needs to be encoded.</param>
+        /// <returns>A byte array holding the UTF-8 encoded Base64Url text; an empty array for empty input.</returns>
+        /// <remarks>The output will not be padded even if the input is not a multiple of 3.</remarks>
+        public static byte[] EncodeToUtf8(ReadOnlySpan<byte> source)
+        {
+            if (source.Length == 0)
+            {
+                return Array.Empty<byte>();
+            }
+
+            // Encode straight into the array that is returned. The size is caller-controlled, so a
+            // stackalloc of GetEncodedLength(source.Length) bytes could overflow the stack.
+            byte[] destination = new byte[GetEncodedLength(source.Length)];
+            EncodeToUtf8(source, destination, out _, out int written);
+
+            return written == destination.Length ? destination : destination.AsSpan(0, written).ToArray();
+        }
+
+        /// <summary>
+        /// Encodes the span of binary data into Base64Url text written as UTF-16 chars.
+        /// </summary>
+        /// <param name="source">The input span which contains binary data that needs to be encoded.</param>
+        /// <param name="destination">The output span which receives the Base64Url text, one char per encoded character.</param>
+        /// <param name="bytesConsumed">The number of input bytes consumed during the operation. This can be used to slice the input for subsequent calls, if necessary.</param>
+        /// <param name="charsWritten">The number of chars written into the output span. This can be used to slice the output for subsequent calls, if necessary.</param>
+        /// <param name="isFinalBlock"><see langword="true"/> (default) when <paramref name="source"/> contains the entire data to encode.
+        /// Set to <see langword="false"/> if this method is being called in a loop and more input data may follow.
+        /// At the end of the loop, call this (potentially with an empty source buffer) passing <see langword="true"/>.</param>
+        /// <returns>It returns the OperationStatus enum values:
+        /// - Done - on successful processing of the entire input span
+        /// - DestinationTooSmall - if there is not enough space in the output span to fit the encoded input
+        /// - NeedMoreData - only if <paramref name="isFinalBlock"/> is <see langword="false"/>
+        /// It does not return InvalidData since that is not possible for base64 encoding.
+        /// </returns>
+        /// <remarks>The output will not be padded even if the input is not a multiple of 3.</remarks>
+        public static OperationStatus EncodeToChars(ReadOnlySpan<byte> source, Span<char> destination, out int bytesConsumed, out int charsWritten, bool isFinalBlock = true)
+        {
+            if (source.Length == 0)
+            {
+                bytesConsumed = 0;
+                charsWritten = 0;
+                return OperationStatus.Done;
+            }
+
+            // The encoded text is ASCII: exactly one byte per output char. Use the front of the
+            // destination's own memory as the byte scratch area, capped at destination.Length bytes so
+            // that "destination too small" is decided in chars rather than in raw bytes.
+            Span<byte> utf8 = MemoryMarshal.AsBytes(destination).Slice(0, destination.Length);
+            OperationStatus status = EncodeToUtf8(source, utf8, out bytesConsumed, out charsWritten, isFinalBlock);
+
+            // Widen in place, back to front. Storing char i touches bytes 2i and 2i + 1, while every
+            // byte still to be read has an index below i, so no unread byte is overwritten.
+            for (int i = charsWritten - 1; i >= 0; i--)
+            {
+                destination[i] = (char)utf8[i];
+            }
+
+            return status;
+        }
+
+        /// <summary>
+        /// Encodes the span of binary data into Base64Url text written as UTF-16 chars.
+        /// </summary>
+        /// <param name="source">The input span which contains binary data that needs to be encoded.</param>
+        /// <param name="destination">The output span which receives the Base64Url text, one char per encoded character.</param>
+        /// <returns>The number of chars written into the destination span. This can be used to slice the output for subsequent calls, if necessary.</returns>
+        /// <remarks>The output will not be padded even if the input is not a multiple of 3.</remarks>
+        public static int EncodeToChars(ReadOnlySpan<byte> source, Span<char> destination)
+        {
+            EncodeToChars(source, destination, out _, out int charsWritten);
+            return charsWritten;
+        }
+
+        /// <summary>
+        /// Encodes the span of binary data into a newly allocated char array of Base64Url text.
+        /// </summary>
+        /// <param name="source">The input span which contains binary data that needs to be encoded.</param>
+        /// <returns>A char array holding the Base64Url text; an empty array for empty input.</returns>
+        /// <remarks>The output will not be padded even if the input is not a multiple of 3.</remarks>
+        public static char[] EncodeToChars(ReadOnlySpan<byte> source)
+        {
+            if (source.Length == 0)
+            {
+                return Array.Empty<char>();
+            }
+
+            // Heap allocation: the length is caller-controlled, so it must not be stackalloc'ed.
+            char[] destination = new char[GetEncodedLength(source.Length)];
+            EncodeToChars(source, destination, out _, out int charsWritten);
+
+            return charsWritten == destination.Length ? destination : destination.AsSpan(0, charsWritten).ToArray();
+        }
+
+        /// <summary>
+        /// Encodes the span of binary data into a string of Base64Url text.
+        /// </summary>
+        /// <param name="source">The input span which contains binary data that needs to be encoded.</param>
+        /// <returns>A string holding the Base64Url text; <see cref="string.Empty"/> for empty input.</returns>
+        /// <remarks>The output will not be padded even if the input is not a multiple of 3.</remarks>
+        public static string EncodeToString(ReadOnlySpan<byte> source)
+        {
+            if (source.Length == 0)
+            {
+                return string.Empty;
+            }
+
+            // Small outputs use a fixed-size stack buffer; larger ones rent from the shared pool so
+            // the stack usage never depends on the caller-controlled input length.
+            const int StackallocThreshold = 256;
+
+            int encodedLength = GetEncodedLength(source.Length);
+            byte[]? rented = null;
+            Span<byte> utf8 = encodedLength <= StackallocThreshold
+                ? stackalloc byte[StackallocThreshold]
+                : (rented = ArrayPool<byte>.Shared.Rent(encodedLength));
+
+            EncodeToUtf8(source, utf8.Slice(0, encodedLength), out _, out int bytesWritten);
+
+            // Span<byte>.ToString() returns the type name and length, not the content, so the ASCII
+            // bytes have to be decoded explicitly.
+            string result = System.Text.Encoding.ASCII.GetString(utf8.Slice(0, bytesWritten));
+
+            if (rented is not null)
+            {
+                ArrayPool<byte>.Shared.Return(rented);
+            }
+
+            return result;
+        }
+
+        /// <summary>
+        /// Tries to encode the span of binary data into Base64Url text written as UTF-16 chars.
+        /// </summary>
+        /// <param name="source">The input span which contains binary data that needs to be encoded.</param>
+        /// <param name="destination">The output span which receives the Base64Url text, one char per encoded character.</param>
+        /// <param name="charsWritten">The number of chars written into the output span. This can be used to slice the output for subsequent calls, if necessary.</param>
+        /// <returns><see langword="true"/> if all of <paramref name="source"/> was encoded; otherwise, <see langword="false"/>.</returns>
+        /// <remarks>The output will not be padded even if the input is not a multiple of 3.</remarks>
+        public static bool TryEncodeToChars(ReadOnlySpan<byte> source, Span<char> destination, out int charsWritten)
+        {
+            OperationStatus status = EncodeToChars(source, destination, out _, out charsWritten);
+
+            return status == OperationStatus.Done;
+        }
+
+        ///
+        /// Encode the span of binary data into UTF-8 encoded chars represented as Base64Url.
+        ///
+        /// <param name="source">The input span which contains binary data that needs to be encoded.</param>
+        /// <param name="destination">The output span which receives the UTF-8 encoded Base64Url text.</param>
+        /// <param name="charsWritten">The number of bytes written into <paramref name="destination"/> (one byte per encoded character). This can be used to slice the output for subsequent calls, if necessary.</param>
+        /// <returns><see langword="true"/> if all of <paramref name="source"/> was encoded; otherwise, <see langword="false"/>.</returns>
+        /// <remarks>The output will not be padded even if the input is not a multiple of 3.</remarks>
+        public static bool TryEncodeToUtf8(ReadOnlySpan<byte> source, Span<byte> destination, out int charsWritten)
+        {
+            OperationStatus status = EncodeToUtf8(source, destination, out _, out charsWritten);
+
+            return status == OperationStatus.Done;
+        }
+
+        /// <summary>
+        /// Encode the span of binary data (in-place) into UTF-8 encoded text represented as Base64Url.
+        /// The encoded text output is larger than the binary data contained in the input (the operation inflates the data).
+        /// </summary>
+        /// <param name="buffer">The input span which contains binary data that needs to be encoded.
+        /// It needs to be large enough to fit the result of the operation.</param>
+        /// <param name="dataLength">The amount of binary data contained within the buffer that needs to be encoded
+        /// (and needs to be smaller than the buffer length).</param>
+        /// <param name="bytesWritten">The number of bytes written into the buffer.</param>
+        /// <returns><see langword="true"/> if bytes encoded successfully, otherwise <see langword="false"/>.</returns>
+        /// <remarks>The output will not be padded even if the input is not a multiple of 3.</remarks>
+        public static unsafe bool TryEncodeToUtf8InPlace(Span<byte> buffer, int dataLength, out int bytesWritten)
+        {
+            if (buffer.IsEmpty)
+            {
+                bytesWritten = 0;
+                return true;
+            }
+
+            fixed (byte* bufferBytes = &MemoryMarshal.GetReference(buffer))
+            {
+                int encodedLength = GetEncodedLength(dataLength);
+                if (buffer.Length < encodedLength)
+                {
+                    bytesWritten = 0;
+                    return false;
+                }
+
+                int leftover = dataLength % 3; // how many bytes left after packs of 3
+
+                // The unpadded tail occupies leftover + 1 characters at the very end of the output;
+                // without a tail the last full group of 4 characters does.
+                uint destinationIndex = leftover > 0 ? (uint)(encodedLength - leftover - 1) : (uint)(encodedLength - 4);
+                uint sourceIndex = (uint)(dataLength - leftover);
+                ref byte encodingMap = ref MemoryMarshal.GetReference(Base64UrlEncoder.EncodingMap);
+
+                // encode last pack to avoid conditional in the main loop
+                if (leftover != 0)
+                {
+                    // The tail is only 2 or 3 characters long. It is stored byte by byte because a
+                    // 4-byte store would write past the encoded length, and past the end of the
+                    // buffer when buffer.Length == encodedLength.
+                    // Source bytes are read into locals first because source and destination overlap.
+                    uint t0 = bufferBytes[sourceIndex];
+
+                    if (leftover == 1)
+                    {
+                        bufferBytes[destinationIndex] = Unsafe.Add(ref encodingMap, (IntPtr)(t0 >> 2));
+                        bufferBytes[destinationIndex + 1] = Unsafe.Add(ref encodingMap, (IntPtr)((t0 & 0x03) << 4));
+                    }
+                    else
+                    {
+                        uint t1 = bufferBytes[sourceIndex + 1];
+
+                        bufferBytes[destinationIndex] = Unsafe.Add(ref encodingMap, (IntPtr)(t0 >> 2));
+                        bufferBytes[destinationIndex + 1] = Unsafe.Add(ref encodingMap, (IntPtr)(((t0 & 0x03) << 4) | (t1 >> 4)));
+                        bufferBytes[destinationIndex + 2] = Unsafe.Add(ref encodingMap, (IntPtr)((t1 & 0x0F) << 2));
+                    }
+
+                    destinationIndex -= 4;
+                }
+
+                // Walk the full 3-byte groups back to front so that no unread input is overwritten.
+                sourceIndex -= 3;
+                while ((int)sourceIndex >= 0)
+                {
+                    uint result = Encode(bufferBytes + sourceIndex, ref encodingMap);
+                    Unsafe.WriteUnaligned(bufferBytes + destinationIndex, result);
+                    destinationIndex -= 4;
+                    sourceIndex -= 3;
+                }
+
+                bytesWritten = encodedLength;
+                return true;
+            }
+        }
+
+        // Base64Url flavour of the encoder policy consumed by the shared Base64 encoding routines:
+        // URL-safe alphabet ('-' and '_' instead of '+' and '/') and no '=' padding.
+        private readonly struct Base64UrlEncoder : IBase64Encoder
+        {
+            public static ReadOnlySpan<byte> EncodingMap => "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_"u8;
+
+            // Per-range offsets from 6-bit value to ASCII; -17 maps 62 to '-' (45), 32 maps 63 to '_' (95).
+            public static Vector256<sbyte> Avx2Lut => Vector256.Create(
+                65, 71, -4, -4,
+                -4, -4, -4, -4,
+                -4, -4, -4, -4,
+                -17, 32, 0, 0,
+                65, 71, -4, -4,
+                -4, -4, -4, -4,
+                -4, -4, -4, -4,
+                -17, 32, 0, 0);
+
+            public static Vector128<byte> AdvSimdLut4 => Vector128.Create("wxyz0123456789-_"u8).AsByte();
+
+            public static Vector128<byte> Ssse3AdvSimdLut => Vector128.Create(0xFCFC4741, 0xFCFCFCFC, 0xFCFCFCFC, 0x000020EF).AsByte();
+
+            // One leftover input byte produces 2 characters (no padding).
+            public static int IncrementPadTwo => 2;
+
+            // Two leftover input bytes produce 3 characters (no padding).
+            public static int IncrementPadOne => 3;
+
+            // Returns how many source bytes can be encoded into a destination of destLength bytes.
+            public static int GetMaxSrcLength(int srcLength, int destLength)
+            {
+                if (srcLength <= MaximumEncodeLength && destLength >= GetEncodedLength(srcLength))
+                {
+                    return srcLength;
+                }
+
+                // Every 4 destination characters hold 3 source bytes. A trailing partial group of r
+                // characters holds only r - 1 source bytes (2 characters per byte, 3 per two bytes);
+                // a single spare character cannot hold any.
+                int remainder = destLength % 4;
+                return (destLength >> 2) * 3 + (remainder > 0 ? remainder - 1 : 0);
+            }
+
+            // Encodes one input byte into 2 characters packed into a uint so that storing the uint at
+            // the destination places the first character at the lowest address on either endianness.
+            [MethodImpl(MethodImplOptions.AggressiveInlining)]
+            public static unsafe uint EncodeOneOptionallyPadTwo(byte* oneByte, ref byte encodingMap)
+            {
+                uint t0 = oneByte[0];
+
+                uint i = t0 << 8;
+
+                uint i0 = Unsafe.Add(ref encodingMap, (IntPtr)(i >> 10));
+                uint i1 = Unsafe.Add(ref encodingMap, (IntPtr)((i >> 4) & 0x3F));
+
+                if (BitConverter.IsLittleEndian)
+                {
+                    return i0 | (i1 << 8);
+                }
+                else
+                {
+                    // Big-endian stores the most significant byte first, so the characters must sit
+                    // in the two high bytes; the low bytes would otherwise emit leading NULs.
+                    return (i0 << 24) | (i1 << 16);
+                }
+            }
+
+            // Encodes two input bytes into 3 characters packed into a uint; same layout rule as above.
+            [MethodImpl(MethodImplOptions.AggressiveInlining)]
+            public static unsafe uint EncodeTwoOptionallyPadOne(byte* twoBytes, ref byte encodingMap)
+            {
+                uint t0 = twoBytes[0];
+                uint t1 = twoBytes[1];
+
+                uint i = (t0 << 16) | (t1 << 8);
+
+                uint i0 = Unsafe.Add(ref encodingMap, (IntPtr)(i >> 18));
+                uint i1 = Unsafe.Add(ref encodingMap, (IntPtr)((i >> 12) & 0x3F));
+                uint i2 = Unsafe.Add(ref encodingMap, (IntPtr)((i >> 6) & 0x3F));
+
+                if (BitConverter.IsLittleEndian)
+                {
+                    return i0 | (i1 << 8) | (i2 << 16);
+                }
+                else
+                {
+                    return (i0 << 24) | (i1 << 16) | (i2 << 8);
+                }
+            }
+        }
+    }
+}
diff --git a/src/libraries/System.Private.CoreLib/src/System/Buffers/Text/Base64Url/Base64UrlValidator.cs b/src/libraries/System.Private.CoreLib/src/System/Buffers/Text/Base64Url/Base64UrlValidator.cs
new file mode 100644
index 00000000000000..502071837dacb2
--- /dev/null
+++ b/src/libraries/System.Private.CoreLib/src/System/Buffers/Text/Base64Url/Base64UrlValidator.cs
@@ -0,0 +1,90 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+namespace System.Buffers.Text
+{
+    public static partial class Base64Url
+    {
+        /// Validates that the specified span of text is comprised of valid base-64 encoded data.
+        /// A span of text to validate.
+        /// if contains a valid, decodable sequence of base-64 encoded data; otherwise, .
+        /// <remarks>
+        /// Padding is optional for Base64Url. Any amount of whitespace is allowed anywhere in the input,
+        /// where whitespace is defined as the characters ' ', '\t', '\r', or '\n'.
+        /// TODO: reference the Base64Url decoding APIs here once they are implemented.
+        /// </remarks>
+        public static bool IsValid(ReadOnlySpan<char> base64UrlText) =>
+            Base64.IsValid<char, Base64UrlCharValidatable>(base64UrlText, out _);
+
+        /// <summary>Validates that the specified span of text is comprised of valid Base64Url encoded data.</summary>
+        /// <param name="base64UrlText">A span of text to validate.</param>
+        /// <param name="decodedLength">If the method returns true, the number of decoded bytes that will result from decoding the input text.</param>
+        /// <returns><see langword="true"/> if <paramref name="base64UrlText"/> contains a valid, decodable sequence of Base64Url encoded data; otherwise, <see langword="false"/>.</returns>
+        /// <remarks>
+        /// Padding is optional for Base64Url. Any amount of whitespace is allowed anywhere in the input,
+        /// where whitespace is defined as the characters ' ', '\t', '\r', or '\n'.
+        /// TODO: reference the Base64Url decoding APIs here once they are implemented.
+        /// </remarks>
+        public static bool IsValid(ReadOnlySpan<char> base64UrlText, out int decodedLength) =>
+            Base64.IsValid<char, Base64UrlCharValidatable>(base64UrlText, out decodedLength);
+
+        /// <summary>Validates that the specified span of UTF-8 text is comprised of valid Base64Url encoded data.</summary>
+        /// <param name="utf8Base64UrlText">A span of UTF-8 text to validate.</param>
+        /// <returns><see langword="true"/> if <paramref name="utf8Base64UrlText"/> contains a valid, decodable sequence of Base64Url encoded data; otherwise, <see langword="false"/>.</returns>
+        /// <remarks>
+        /// Padding is optional for Base64Url. Any amount of whitespace is allowed anywhere in the input,
+        /// where whitespace is defined as the characters ' ', '\t', '\r', or '\n' (as bytes).
+        /// </remarks>
+        public static bool IsValid(ReadOnlySpan<byte> utf8Base64UrlText) =>
+            Base64.IsValid<byte, Base64UrlByteValidatable>(utf8Base64UrlText, out _);
+
+        /// <summary>Validates that the specified span of UTF-8 text is comprised of valid Base64Url encoded data.</summary>
+        /// <param name="utf8Base64UrlText">A span of UTF-8 text to validate.</param>
+        /// <param name="decodedLength">If the method returns true, the number of decoded bytes that will result from decoding the input UTF-8 text.</param>
+        /// <returns><see langword="true"/> if <paramref name="utf8Base64UrlText"/> contains a valid, decodable sequence of Base64Url encoded data; otherwise, <see langword="false"/>.</returns>
+        /// <remarks>
+        /// Padding is optional for Base64Url. Any amount of whitespace is allowed anywhere in the input,
+        /// where whitespace is defined as the characters ' ', '\t', '\r', or '\n' (as bytes).
+        /// </remarks>
+        public static bool IsValid(ReadOnlySpan<byte> utf8Base64UrlText, out int decodedLength) =>
+            Base64.IsValid<byte, Base64UrlByteValidatable>(utf8Base64UrlText, out decodedLength);
+
+        // Shared length check for both validatable policies. Reports the decoded byte count for a
+        // text of `length` significant characters, of which `paddingCount` are padding characters.
+        private static bool ValidateAndDecodeLength(int length, int paddingCount, out int decodedLength)
+        {
+            // Padding is optional for Base64Url, so the length need not be a multiple of 4.
+            int remainder = length % 4;
+
+            // A single trailing character carries only 6 bits and can never complete a byte, so every
+            // length of the form 4k + 1 (1, 5, 9, ...) is undecodable, not just length == 1.
+            if (remainder == 1)
+            {
+                decodedLength = 0;
+                return false;
+            }
+
+            // Full groups decode to 3 bytes each; a trailing group of 2 or 3 characters decodes to
+            // 1 or 2 bytes. Padding characters are counted in length, so they are subtracted again.
+            decodedLength = (int)((uint)length / 4 * 3) + (remainder > 0 ? remainder - 1 : 0) - paddingCount;
+            return true;
+        }
+
+        private const uint UrlEncodingPad = '%'; // url padding
+
+        // Validation policy for UTF-16 input: URL-safe alphabet, '=' or '%' accepted as padding.
+        private readonly struct Base64UrlCharValidatable : Base64.IBase64Validatable<char>
+        {
+            private static readonly SearchValues<char> s_validBase64UrlChars = SearchValues.Create("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_");
+
+            public static int IndexOfAnyExcept(ReadOnlySpan<char> span) => span.IndexOfAnyExcept(s_validBase64UrlChars);
+            public static bool IsWhiteSpace(char value) => Base64.IsWhiteSpace(value);
+            public static bool IsEncodingPad(char value) => value == Base64.EncodingPad || value == UrlEncodingPad;
+            public static bool ValidateAndDecodeLength(int length, int paddingCount, out int decodedLength) =>
+                Base64Url.ValidateAndDecodeLength(length, paddingCount, out decodedLength);
+        }
+
+        // Validation policy for UTF-8 input: URL-safe alphabet, '=' or '%' accepted as padding.
+        private readonly struct Base64UrlByteValidatable : Base64.IBase64Validatable<byte>
+        {
+            private static readonly SearchValues<byte> s_validBase64UrlChars = SearchValues.Create(Base64UrlEncoder.EncodingMap);
+
+            public static int IndexOfAnyExcept(ReadOnlySpan<byte> span) => span.IndexOfAnyExcept(s_validBase64UrlChars);
+            public static bool IsWhiteSpace(byte value) => Base64.IsWhiteSpace(value);
+            public static bool IsEncodingPad(byte value) => value == Base64.EncodingPad || value == UrlEncodingPad;
+            public static bool ValidateAndDecodeLength(int length, int paddingCount, out int
decodedLength) => + Base64Url.ValidateAndDecodeLength(length, paddingCount, out decodedLength); + } + } +} diff --git a/src/libraries/System.Private.CoreLib/src/System/Buffers/Text/Base64Validator.cs b/src/libraries/System.Private.CoreLib/src/System/Buffers/Text/Base64Validator.cs index 22071725a23520..8f8e22ae94263b 100644 --- a/src/libraries/System.Private.CoreLib/src/System/Buffers/Text/Base64Validator.cs +++ b/src/libraries/System.Private.CoreLib/src/System/Buffers/Text/Base64Validator.cs @@ -53,7 +53,7 @@ public static bool IsValid(ReadOnlySpan base64TextUtf8) => public static bool IsValid(ReadOnlySpan base64TextUtf8, out int decodedLength) => IsValid(base64TextUtf8, out decodedLength); - private static bool IsValid(ReadOnlySpan base64Text, out int decodedLength) + internal static bool IsValid(ReadOnlySpan base64Text, out int decodedLength) where TBase64Validatable : IBase64Validatable { int length = 0, paddingCount = 0; @@ -116,14 +116,15 @@ private static bool IsValid(ReadOnlySpan base64Text, o break; } - if (length % 4 != 0) + if (!TBase64Validatable.ValidateAndDecodeLength(length, paddingCount, out decodedLength)) { goto Fail; } + + return true; } - // Remove padding to get exact length. - decodedLength = (int)((uint)length / 4 * 3) - paddingCount; + decodedLength = 0; return true; Fail: @@ -131,11 +132,25 @@ private static bool IsValid(ReadOnlySpan base64Text, o return false; } - private interface IBase64Validatable + private static bool ValidateAndDecodeLength(int length, int paddingCount, out int decodedLength) + { + if (length % 4 == 0) + { + // Remove padding to get exact length. 
+ decodedLength = (int)((uint)length / 4 * 3) - paddingCount; + return true; + } + + decodedLength = 0; + return false; + } + + internal interface IBase64Validatable { static abstract int IndexOfAnyExcept(ReadOnlySpan span); static abstract bool IsWhiteSpace(T value); static abstract bool IsEncodingPad(T value); + static abstract bool ValidateAndDecodeLength(int length, int paddingCount, out int decodedLength); } private readonly struct Base64CharValidatable : IBase64Validatable @@ -145,15 +160,19 @@ private interface IBase64Validatable public static int IndexOfAnyExcept(ReadOnlySpan span) => span.IndexOfAnyExcept(s_validBase64Chars); public static bool IsWhiteSpace(char value) => Base64.IsWhiteSpace(value); public static bool IsEncodingPad(char value) => value == EncodingPad; + public static bool ValidateAndDecodeLength(int length, int paddingCount, out int decodedLength) => + Base64.ValidateAndDecodeLength(length, paddingCount, out decodedLength); } private readonly struct Base64ByteValidatable : IBase64Validatable { - private static readonly SearchValues s_validBase64Chars = SearchValues.Create(EncodingMap); + private static readonly SearchValues s_validBase64Chars = SearchValues.Create(Base64Encoder.EncodingMap); public static int IndexOfAnyExcept(ReadOnlySpan span) => span.IndexOfAnyExcept(s_validBase64Chars); public static bool IsWhiteSpace(byte value) => Base64.IsWhiteSpace(value); public static bool IsEncodingPad(byte value) => value == EncodingPad; + public static bool ValidateAndDecodeLength(int length, int paddingCount, out int decodedLength) => + Base64.ValidateAndDecodeLength(length, paddingCount, out decodedLength); } } }