diff --git a/src/Http/Http/src/Features/QueryFeature.cs b/src/Http/Http/src/Features/QueryFeature.cs index dd3a440202e8..af1707aeeff7 100644 --- a/src/Http/Http/src/Features/QueryFeature.cs +++ b/src/Http/Http/src/Features/QueryFeature.cs @@ -2,13 +2,10 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; -using System.Buffers; using System.Collections.Generic; using System.Runtime.CompilerServices; -using System.Runtime.InteropServices; -using System.Runtime.Intrinsics; -using System.Runtime.Intrinsics.X86; using Microsoft.AspNetCore.Internal; +using Microsoft.AspNetCore.WebUtilities; using Microsoft.Extensions.Primitives; namespace Microsoft.AspNetCore.Http.Features @@ -19,7 +16,7 @@ namespace Microsoft.AspNetCore.Http.Features public class QueryFeature : IQueryFeature { // Lambda hoisted to static readonly field to improve inlining https://github.com/dotnet/roslyn/issues/13624 - private readonly static Func _nullRequestFeature = f => null; + private static readonly Func _nullRequestFeature = f => null; private FeatureReferences _features; @@ -113,48 +110,10 @@ public IQueryCollection Query } var accumulator = new KvpAccumulator(); - var query = queryString.AsSpan(); - - if (query[0] == '?') + var enumerable = new QueryStringEnumerable(queryString.AsSpan()); + foreach (var pair in enumerable) { - query = query[1..]; - } - - while (!query.IsEmpty) - { - var delimiterIndex = query.IndexOf('&'); - - var querySegment = delimiterIndex >= 0 - ? query.Slice(0, delimiterIndex) - : query; - - var equalIndex = querySegment.IndexOf('='); - - if (equalIndex >= 0) - { - var name = SpanHelper.ReplacePlusWithSpace(querySegment.Slice(0, equalIndex)); - var value = SpanHelper.ReplacePlusWithSpace(querySegment.Slice(equalIndex + 1)); - - accumulator.Append( - Uri.UnescapeDataString(name), - Uri.UnescapeDataString(value)); - } - else - { - if (!querySegment.IsEmpty) - { - var name = SpanHelper.ReplacePlusWithSpace(querySegment); - - accumulator.Append(Uri.UnescapeDataString(name)); - } - } - - if (delimiterIndex < 0) - { - break; - } - - query = query.Slice(delimiterIndex + 1); + accumulator.Append(pair.DecodeName(), pair.DecodeValue()); } return accumulator.HasValues @@ -171,8 +130,8 @@ internal struct KvpAccumulator private AdaptiveCapacityDictionary _accumulator; private AdaptiveCapacityDictionary> _expandingAccumulator; - public void Append(ReadOnlySpan key, ReadOnlySpan value = default) - => Append(key.ToString(), value.IsEmpty ? string.Empty : value.ToString()); + public void Append(ReadOnlySpan key, ReadOnlySpan value) + => Append(key.ToString(), value.ToString()); /// /// This API supports infrastructure and is not intended to be used @@ -263,58 +222,5 @@ public AdaptiveCapacityDictionary GetResults() return _accumulator ?? new AdaptiveCapacityDictionary(0, StringComparer.OrdinalIgnoreCase); } } - - private static class SpanHelper - { - private static readonly SpanAction s_replacePlusWithSpace = ReplacePlusWithSpaceCore; - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static unsafe string ReplacePlusWithSpace(ReadOnlySpan span) - { - fixed (char* ptr = &MemoryMarshal.GetReference(span)) - { - return string.Create(span.Length, (IntPtr)ptr, s_replacePlusWithSpace); - } - } - - private static unsafe void ReplacePlusWithSpaceCore(Span buffer, IntPtr state) - { - fixed (char* ptr = &MemoryMarshal.GetReference(buffer)) - { - var input = (ushort*)state.ToPointer(); - var output = (ushort*)ptr; - - var i = (nint)0; - var n = (nint)(uint)buffer.Length; - - if (Sse41.IsSupported && n >= Vector128.Count) - { - var vecPlus = Vector128.Create((ushort)'+'); - var vecSpace = Vector128.Create((ushort)' '); - - do - { - var vec = Sse2.LoadVector128(input + i); - var mask = Sse2.CompareEqual(vec, vecPlus); - var res = Sse41.BlendVariable(vec, vecSpace, mask); - Sse2.Store(output + i, res); - i += Vector128.Count; - } while (i <= n - Vector128.Count); - } - - for (; i < n; ++i) - { - if (input[i] != '+') - { - output[i] = input[i]; - } - else - { - output[i] = ' '; - } - } - } - } - } } } diff --git a/src/Http/WebUtilities/src/Microsoft.AspNetCore.WebUtilities.csproj b/src/Http/WebUtilities/src/Microsoft.AspNetCore.WebUtilities.csproj index 3e3560224e2d..f5ef216dccf5 100644 --- a/src/Http/WebUtilities/src/Microsoft.AspNetCore.WebUtilities.csproj +++ b/src/Http/WebUtilities/src/Microsoft.AspNetCore.WebUtilities.csproj @@ -4,7 +4,7 @@ ASP.NET Core utilities, such as for working with forms, multipart messages, and query strings. $(DefaultNetCoreTargetFramework) true - $(DefineConstants);WebEncoders_In_WebUtilities + $(DefineConstants);WebEncoders_In_WebUtilities;QueryStringEnumerable_In_WebUtilities true true aspnetcore @@ -13,6 +13,7 @@ + diff --git a/src/Http/WebUtilities/src/PublicAPI.Unshipped.txt b/src/Http/WebUtilities/src/PublicAPI.Unshipped.txt index bf67ce991730..85d85eece84f 100644 --- a/src/Http/WebUtilities/src/PublicAPI.Unshipped.txt +++ b/src/Http/WebUtilities/src/PublicAPI.Unshipped.txt @@ -2,6 +2,17 @@ *REMOVED*static Microsoft.AspNetCore.WebUtilities.QueryHelpers.ParseNullableQuery(string! queryString) -> System.Collections.Generic.Dictionary? Microsoft.AspNetCore.WebUtilities.FileBufferingReadStream.MemoryThreshold.get -> int Microsoft.AspNetCore.WebUtilities.FileBufferingWriteStream.MemoryThreshold.get -> int +Microsoft.AspNetCore.WebUtilities.QueryStringEnumerable +Microsoft.AspNetCore.WebUtilities.QueryStringEnumerable.EncodedNameValuePair +Microsoft.AspNetCore.WebUtilities.QueryStringEnumerable.EncodedNameValuePair.DecodeName() -> System.ReadOnlySpan +Microsoft.AspNetCore.WebUtilities.QueryStringEnumerable.EncodedNameValuePair.DecodeValue() -> System.ReadOnlySpan +Microsoft.AspNetCore.WebUtilities.QueryStringEnumerable.EncodedNameValuePair.EncodedName.get -> System.ReadOnlySpan +Microsoft.AspNetCore.WebUtilities.QueryStringEnumerable.EncodedNameValuePair.EncodedValue.get -> System.ReadOnlySpan +Microsoft.AspNetCore.WebUtilities.QueryStringEnumerable.Enumerator +Microsoft.AspNetCore.WebUtilities.QueryStringEnumerable.Enumerator.Current.get -> Microsoft.AspNetCore.WebUtilities.QueryStringEnumerable.EncodedNameValuePair +Microsoft.AspNetCore.WebUtilities.QueryStringEnumerable.Enumerator.MoveNext() -> bool +Microsoft.AspNetCore.WebUtilities.QueryStringEnumerable.GetEnumerator() -> Microsoft.AspNetCore.WebUtilities.QueryStringEnumerable.Enumerator +Microsoft.AspNetCore.WebUtilities.QueryStringEnumerable.QueryStringEnumerable(System.ReadOnlySpan queryString) -> void override Microsoft.AspNetCore.WebUtilities.BufferedReadStream.ReadAsync(System.Memory buffer, System.Threading.CancellationToken cancellationToken) -> System.Threading.Tasks.ValueTask override Microsoft.AspNetCore.WebUtilities.FileBufferingWriteStream.WriteAsync(System.ReadOnlyMemory buffer, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.ValueTask static Microsoft.AspNetCore.WebUtilities.QueryHelpers.ParseNullableQuery(string? queryString) -> System.Collections.Generic.Dictionary? diff --git a/src/Http/WebUtilities/src/QueryHelpers.cs b/src/Http/WebUtilities/src/QueryHelpers.cs index 64a2ba2db6ba..c0b8f81562fc 100644 --- a/src/Http/WebUtilities/src/QueryHelpers.cs +++ b/src/Http/WebUtilities/src/QueryHelpers.cs @@ -6,6 +6,7 @@ using System.Linq; using System.Text; using System.Text.Encodings.Web; +using Microsoft.AspNetCore.Internal; using Microsoft.Extensions.Primitives; namespace Microsoft.AspNetCore.WebUtilities @@ -172,59 +173,11 @@ public static Dictionary ParseQuery(string? queryString) public static Dictionary? ParseNullableQuery(string? queryString) { var accumulator = new KeyValueAccumulator(); + var enumerable = new QueryStringEnumerable(queryString); - if (string.IsNullOrEmpty(queryString) || queryString == "?") + foreach (var pair in enumerable) { - return null; - } - - int scanIndex = 0; - if (queryString[0] == '?') - { - scanIndex = 1; - } - - int textLength = queryString.Length; - int equalIndex = queryString.IndexOf('='); - if (equalIndex == -1) - { - equalIndex = textLength; - } - while (scanIndex < textLength) - { - int delimiterIndex = queryString.IndexOf('&', scanIndex); - if (delimiterIndex == -1) - { - delimiterIndex = textLength; - } - if (equalIndex < delimiterIndex) - { - while (scanIndex != equalIndex && char.IsWhiteSpace(queryString[scanIndex])) - { - ++scanIndex; - } - string name = queryString.Substring(scanIndex, equalIndex - scanIndex); - string value = queryString.Substring(equalIndex + 1, delimiterIndex - equalIndex - 1); - accumulator.Append( - Uri.UnescapeDataString(name.Replace('+', ' ')), - Uri.UnescapeDataString(value.Replace('+', ' '))); - equalIndex = queryString.IndexOf('=', delimiterIndex); - if (equalIndex == -1) - { - equalIndex = textLength; - } - } - else - { - if (delimiterIndex > scanIndex) - { - string name = queryString.Substring(scanIndex, delimiterIndex - scanIndex); - accumulator.Append( - Uri.UnescapeDataString(name.Replace('+', ' ')), - string.Empty); - } - } - scanIndex = delimiterIndex + 1; + accumulator.Append(pair.DecodeName().ToString(), pair.DecodeValue().ToString()); } if (!accumulator.HasValues) diff --git a/src/Shared/QueryStringEnumerable.cs b/src/Shared/QueryStringEnumerable.cs new file mode 100644 index 000000000000..0a239498de5e --- /dev/null +++ b/src/Shared/QueryStringEnumerable.cs @@ -0,0 +1,208 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Buffers; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; +using System.Runtime.Intrinsics; +using System.Runtime.Intrinsics.X86; + +#if QueryStringEnumerable_In_WebUtilities +namespace Microsoft.AspNetCore.WebUtilities +#else +namespace Microsoft.AspNetCore.Internal +#endif +{ + /// + /// An enumerable that can supply the name/value pairs from a URI query string. + /// +#if QueryStringEnumerable_In_WebUtilities + public +#else + internal +#endif + readonly ref struct QueryStringEnumerable + { + private readonly ReadOnlySpan _queryString; + + /// + /// Constructs an instance of . + /// + /// The query string. + public QueryStringEnumerable(ReadOnlySpan queryString) + { + _queryString = queryString; + } + + /// + /// Retrieves an object that can iterate through the name/value pairs in the query string. + /// + /// An object that can iterate through the name/value pairs in the query string. + public Enumerator GetEnumerator() + => new Enumerator(_queryString); + + /// + /// Represents a single name/value pair extracted from a query string during enumeration. + /// + public readonly ref struct EncodedNameValuePair + { + /// + /// Gets the name from this name/value pair in its original encoded form. + /// To get the decoded string, call . + /// + public readonly ReadOnlySpan EncodedName { get; } + + /// + /// Gets the value from this name/value pair in its original encoded form. + /// To get the decoded string, call . + /// + public readonly ReadOnlySpan EncodedValue { get; } + + internal EncodedNameValuePair(ReadOnlySpan encodedName, ReadOnlySpan encodedValue) + { + EncodedName = encodedName; + EncodedValue = encodedValue; + } + + /// + /// Decodes the name from this name/value pair. + /// + /// Characters representing the decoded name. + public ReadOnlySpan DecodeName() + => Decode(EncodedName); + + /// + /// Decodes the value from this name/value pair. + /// + /// Characters representing the decoded value. + public ReadOnlySpan DecodeValue() + => Decode(EncodedValue); + + private static ReadOnlySpan Decode(ReadOnlySpan chars) + { + // If the value is short, it's cheap to check up front if it really needs decoding. If it doesn't, + // then we can save some allocations. + return chars.Length < 16 && chars.IndexOfAny('%', '+') < 0 + ? chars + : Uri.UnescapeDataString(SpanHelper.ReplacePlusWithSpace(chars)); + } + } + + /// + /// An enumerator that supplies the name/value pairs from a URI query string. + /// + public ref struct Enumerator + { + private ReadOnlySpan _query; + + internal Enumerator(ReadOnlySpan query) + { + Current = default; + _query = query.IsEmpty || query[0] != '?' + ? query + : query.Slice(1); + } + + /// + /// Gets the currently referenced key/value pair in the query string being enumerated. + /// + public EncodedNameValuePair Current { get; private set; } + + /// + /// Moves to the next key/value pair in the query string being enumerated. + /// + /// True if there is another key/value pair, otherwise false. + public bool MoveNext() + { + while (!_query.IsEmpty) + { + // Chomp off the next segment + ReadOnlySpan segment; + var delimiterIndex = _query.IndexOf('&'); + if (delimiterIndex >= 0) + { + segment = _query.Slice(0, delimiterIndex); + _query = _query.Slice(delimiterIndex + 1); + } + else + { + segment = _query; + _query = default; + } + + // If it's nonempty, emit it + var equalIndex = segment.IndexOf('='); + if (equalIndex >= 0) + { + Current = new EncodedNameValuePair( + segment.Slice(0, equalIndex), + segment.Slice(equalIndex + 1)); + return true; + } + else if (!segment.IsEmpty) + { + Current = new EncodedNameValuePair(segment, default); + return true; + } + } + + Current = default; + return false; + } + } + + private static class SpanHelper + { + private static readonly SpanAction s_replacePlusWithSpace = ReplacePlusWithSpaceCore; + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static unsafe string ReplacePlusWithSpace(ReadOnlySpan span) + { + fixed (char* ptr = &MemoryMarshal.GetReference(span)) + { + return string.Create(span.Length, (IntPtr)ptr, s_replacePlusWithSpace); + } + } + + private static unsafe void ReplacePlusWithSpaceCore(Span buffer, IntPtr state) + { + fixed (char* ptr = &MemoryMarshal.GetReference(buffer)) + { + var input = (ushort*)state.ToPointer(); + var output = (ushort*)ptr; + + var i = (nint)0; + var n = (nint)(uint)buffer.Length; + + if (Sse41.IsSupported && n >= Vector128.Count) + { + var vecPlus = Vector128.Create('+'); + var vecSpace = Vector128.Create(' '); + + do + { + var vec = Sse2.LoadVector128(input + i); + var mask = Sse2.CompareEqual(vec, vecPlus); + var res = Sse41.BlendVariable(vec, vecSpace, mask); + Sse2.Store(output + i, res); + i += Vector128.Count; + } while (i <= n - Vector128.Count); + } + + for (; i < n; ++i) + { + if (input[i] != '+') + { + output[i] = input[i]; + } + else + { + output[i] = ' '; + } + } + } + } + } + } +} diff --git a/src/Shared/test/Shared.Tests/Microsoft.AspNetCore.Shared.Tests.csproj b/src/Shared/test/Shared.Tests/Microsoft.AspNetCore.Shared.Tests.csproj index 694f29b8d51c..0ed8170ccb9b 100644 --- a/src/Shared/test/Shared.Tests/Microsoft.AspNetCore.Shared.Tests.csproj +++ b/src/Shared/test/Shared.Tests/Microsoft.AspNetCore.Shared.Tests.csproj @@ -31,6 +31,7 @@ + diff --git a/src/Shared/test/Shared.Tests/QueryStringEnumerableTest.cs b/src/Shared/test/Shared.Tests/QueryStringEnumerableTest.cs new file mode 100644 index 000000000000..ad6ad0337c91 --- /dev/null +++ b/src/Shared/test/Shared.Tests/QueryStringEnumerableTest.cs @@ -0,0 +1,129 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; +using Xunit; + +namespace Microsoft.AspNetCore.Internal +{ + public class QueryStringEnumerableTest + { + [Fact] + public void ParseQueryWithUniqueKeysWorks() + { + Assert.Collection(Parse("?key1=value1&key2=value2"), + kvp => AssertKeyValuePair("key1", "value1", kvp), + kvp => AssertKeyValuePair("key2", "value2", kvp)); + } + + [Fact] + public void ParseQueryWithoutQuestionmarkWorks() + { + Assert.Collection(Parse("key1=value1&key2=value2"), + kvp => AssertKeyValuePair("key1", "value1", kvp), + kvp => AssertKeyValuePair("key2", "value2", kvp)); + } + + [Fact] + public void ParseQueryWithDuplicateKeysGroups() + { + Assert.Collection(Parse("?key1=valueA&key2=valueB&key1=valueC"), + kvp => AssertKeyValuePair("key1", "valueA", kvp), + kvp => AssertKeyValuePair("key2", "valueB", kvp), + kvp => AssertKeyValuePair("key1", "valueC", kvp)); + } + + [Fact] + public void ParseQueryWithEmptyValuesWorks() + { + Assert.Collection(Parse("?key1=&key2="), + kvp => AssertKeyValuePair("key1", string.Empty, kvp), + kvp => AssertKeyValuePair("key2", string.Empty, kvp)); + } + + [Fact] + public void ParseQueryWithEmptyKeyWorks() + { + Assert.Collection(Parse("?=value1&="), + kvp => AssertKeyValuePair(string.Empty, "value1", kvp), + kvp => AssertKeyValuePair(string.Empty, string.Empty, kvp)); + } + + [Fact] + public void ParseQueryWithEncodedKeyWorks() + { + Assert.Collection(Parse("?fields+%5BtodoItems%5D"), + kvp => AssertKeyValuePair("fields+%5BtodoItems%5D", string.Empty, kvp)); + } + + [Fact] + public void ParseQueryWithEncodedValueWorks() + { + Assert.Collection(Parse("?=fields+%5BtodoItems%5D"), + kvp => AssertKeyValuePair(string.Empty, "fields+%5BtodoItems%5D", kvp)); + } + + [Theory] + [InlineData("?")] + [InlineData("")] + [InlineData(null)] + [InlineData("?&&")] + public void ParseEmptyOrNullQueryWorks(string queryString) + { + Assert.Empty(Parse(queryString)); + } + + [Fact] + public void ParseIgnoresEmptySegments() + { + Assert.Collection(Parse("?&key1=value1&&key2=value2&"), + kvp => AssertKeyValuePair("key1", "value1", kvp), + kvp => AssertKeyValuePair("key2", "value2", kvp)); + } + + [Theory] + [InlineData("?a+b=c+d", "a b", "c d")] + [InlineData("? %5Bkey%5D = %26value%3D ", " [key] ", " &value= ")] + [InlineData("?+", " ", "")] + [InlineData("?=+", "", " ")] + public void DecodingWorks(string queryString, string expectedDecodedName, string expectedDecodedValue) + { + foreach (var kvp in new QueryStringEnumerable(queryString)) + { + Assert.Equal(expectedDecodedName, kvp.DecodeName().ToString()); + Assert.Equal(expectedDecodedValue, kvp.DecodeValue().ToString()); + } + } + + [Fact] + public void DecodingRetainsSpansIfDecodingNotNeeded() + { + foreach (var kvp in new QueryStringEnumerable("?key=value")) + { + Assert.True(MemoryExtensions.Overlaps(kvp.EncodedName, kvp.DecodeName(), out var nameOffset)); + Assert.True(MemoryExtensions.Overlaps(kvp.EncodedValue, kvp.DecodeValue(), out var valueOffset)); + Assert.Equal(0, nameOffset); + Assert.Equal(0, valueOffset); + } + } + + private static void AssertKeyValuePair(string expectedKey, string expectedValue, (string key, string value) actual) + { + Assert.Equal(expectedKey, actual.key); + Assert.Equal(expectedValue, actual.value); + } + + private static IReadOnlyList<(string key, string value)> Parse(string query) + { + var result = new List<(string key, string value)>(); + var enumerable = new QueryStringEnumerable(query); + foreach (var pair in enumerable) + { + result.Add((pair.EncodedName.ToString(), pair.EncodedValue.ToString())); + } + + return result; + } + } +}