diff --git a/src/Components/Server/src/Microsoft.AspNetCore.Components.Server.csproj b/src/Components/Server/src/Microsoft.AspNetCore.Components.Server.csproj index 884967d73d78..8c3fd43461fb 100644 --- a/src/Components/Server/src/Microsoft.AspNetCore.Components.Server.csproj +++ b/src/Components/Server/src/Microsoft.AspNetCore.Components.Server.csproj @@ -69,6 +69,7 @@ + diff --git a/src/Servers/Kestrel/Core/src/Internal/Http/HttpParser.cs b/src/Servers/Kestrel/Core/src/Internal/Http/HttpParser.cs index 66282381cb8c..2fafdb58a57b 100644 --- a/src/Servers/Kestrel/Core/src/Internal/Http/HttpParser.cs +++ b/src/Servers/Kestrel/Core/src/Internal/Http/HttpParser.cs @@ -21,6 +21,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http; public class HttpParser : IHttpParser where TRequestHandler : IHttpHeadersHandler, IHttpRequestLineHandler { private readonly bool _showErrorDetails; + private readonly bool _disableHttp1LineFeedTerminators; /// /// This API supports framework infrastructure and is not intended to be used @@ -34,9 +35,14 @@ public HttpParser() : this(showErrorDetails: true) /// This API supports framework infrastructure and is not intended to be used /// directly from application code. /// - public HttpParser(bool showErrorDetails) + public HttpParser(bool showErrorDetails) : this(showErrorDetails, AppContext.TryGetSwitch(KestrelServerOptions.DisableHttp1LineFeedTerminatorsSwitchKey, out var disabled) && disabled) + { + } + + internal HttpParser(bool showErrorDetails, bool disableHttp1LineFeedTerminators) { _showErrorDetails = showErrorDetails; + _disableHttp1LineFeedTerminators = disableHttp1LineFeedTerminators; } // byte types don't have a data type annotation so we pre-cast them; to avoid in-place casts @@ -135,9 +141,15 @@ private void ParseRequestLine(TRequestHandler handler, ReadOnlySpan reques // Version + CR is 9 bytes which should take us to .Length // LF should have been dropped prior to method call - if ((uint)offset + 9 != (uint)requestLine.Length || requestLine[offset + sizeof(ulong)] != ByteCR) + if ((uint)offset + 9 != (uint)requestLine.Length || requestLine[offset + 8] != ByteCR) { - RejectRequestLine(requestLine); + // LF should have been dropped prior to method call + // If !_disableHttp1LineFeedTerminators and offset + 8 is .Length, + // then requestLine is valid since it means LF was the next char + if (_disableHttp1LineFeedTerminators || (uint)offset + 8 != (uint)requestLine.Length) + { + RejectRequestLine(requestLine); + } } // Version @@ -164,135 +176,142 @@ public bool ParseHeaders(TRequestHandler handler, ref SequenceReader reade { while (!reader.End) { + // Check if the reader's span contains an LF to skip the reader if possible var span = reader.UnreadSpan; - while (span.Length > 0) + + // Fast path, CR/LF at the beginning + if (span.Length >= 2 && span[0] == ByteCR && span[1] == ByteLF) { - byte ch1; - var ch2 = (byte)0; - var readAhead = 0; + reader.Advance(2); + handler.OnHeadersComplete(endStream: false); + return true; + } - // Fast path, we're still looking at the same span - if (span.Length >= 2) - { - ch1 = span[0]; - ch2 = span[1]; - } - else if (reader.TryRead(out ch1)) // Possibly split across spans - { - // Note if we read ahead by 1 or 2 bytes - readAhead = (reader.TryRead(out ch2)) ? 2 : 1; - } + var foundCrlf = false; - if (ch1 == ByteCR) + var lfOrCrIndex = span.IndexOfAny(ByteCR, ByteLF); + if (lfOrCrIndex >= 0) + { + if (span[lfOrCrIndex] == ByteCR) { - // Check for final CRLF. - if (ch2 == ByteLF) - { - // If we got 2 bytes from the span directly so skip ahead 2 so that - // the reader's state matches what we expect - if (readAhead == 0) - { - reader.Advance(2); - } + // We got a CR. Is this a CR/LF sequence? + var crIndex = lfOrCrIndex; + reader.Advance(crIndex + 1); - // Double CRLF found, so end of headers. - handler.OnHeadersComplete(endStream: false); - return true; + bool hasDataAfterCr; + + if ((uint)span.Length > (uint)(crIndex + 1) && span[crIndex + 1] == ByteLF) + { + // CR/LF in the same span (common case) + span = span.Slice(0, crIndex); + foundCrlf = true; } - else if (readAhead == 1) + else if ((hasDataAfterCr = reader.TryPeek(out byte lfMaybe)) && lfMaybe == ByteLF) { - // Didn't read 2 bytes, reset the reader so we don't consume anything - reader.Rewind(1); - return false; + // CR/LF but split between spans + span = span.Slice(0, span.Length - 1); + foundCrlf = true; } - - Debug.Assert(readAhead == 0 || readAhead == 2); - // Headers don't end in CRLF line. - - KestrelBadHttpRequestException.Throw(RequestRejectionReason.InvalidRequestHeadersNoCRLF); - } - - var length = 0; - // We only need to look for the end if we didn't read ahead; otherwise there isn't enough in - // in the span to contain a header. - if (readAhead == 0) - { - length = span.IndexOfAny(ByteCR, ByteLF); - // If not found length with be -1; casting to uint will turn it to uint.MaxValue - // which will be larger than any possible span.Length. This also serves to eliminate - // the bounds check for the next lookup of span[length] - if ((uint)length < (uint)span.Length) + else { - // Early memory read to hide latency - var expectedCR = span[length]; - // Correctly has a CR, move to next - length++; - - if (expectedCR != ByteCR) + // What's after the CR? + if (!hasDataAfterCr) { - // Sequence needs to be CRLF not LF first. - RejectRequestHeader(span[..length]); + // No more chars after CR? Don't consume an incomplete header + reader.Rewind(crIndex + 1); + return false; } - - if ((uint)length < (uint)span.Length) + else if (crIndex == 0) { - // Early memory read to hide latency - var expectedLF = span[length]; - // Correctly has a LF, move to next - length++; - - if (expectedLF != ByteLF || - length < 5 || - // Exclude the CRLF from the headerLine and parse the header name:value pair - !TryTakeSingleHeader(handler, span[..(length - 2)])) - { - // Sequence needs to be CRLF and not contain an inner CR not part of terminator. - // Less than min possible headerSpan of 5 bytes a:b\r\n - // Not parsable as a valid name:value header pair. - RejectRequestHeader(span[..length]); - } - - // Read the header successfully, skip the reader forward past the headerSpan. - span = span.Slice(length); - reader.Advance(length); + // CR followed by something other than LF + KestrelBadHttpRequestException.Throw(RequestRejectionReason.InvalidRequestHeadersNoCRLF); } else { - // No enough data, set length to 0. - length = 0; + // Include the thing after the CR in the rejection exception. + var stopIndex = crIndex + 2; + RejectRequestHeader(span[..stopIndex]); } } - } - // End found in current span - if (length > 0) - { - continue; - } + if (foundCrlf) + { + // Advance past the LF too + reader.Advance(1); - // We moved the reader to look ahead 2 bytes so rewind the reader - if (readAhead > 0) - { - reader.Rewind(readAhead); + // Empty line? + if (crIndex == 0) + { + handler.OnHeadersComplete(endStream: false); + return true; + } + } } + else + { + // We got an LF with no CR before it. + var lfIndex = lfOrCrIndex; + if (_disableHttp1LineFeedTerminators) + { + RejectRequestHeader(AppendEndOfLine(span[..lfIndex], lineFeedOnly: true)); + } - length = ParseMultiSpanHeader(handler, ref reader); + // Consume the header including the LF + reader.Advance(lfIndex + 1); + + span = span.Slice(0, lfIndex); + if (span.Length == 0) + { + handler.OnHeadersComplete(endStream: false); + return true; + } + } + } + else + { + // No CR or LF. Is this a multi-span header? + int length = ParseMultiSpanHeader(handler, ref reader); if (length < 0) { - // Not there + // Not multi-line, just bad. return false; } + // This was a multi-line header. Advance the reader. reader.Advance(length); - // As we crossed spans set the current span to default - // so we move to the next span on the next iteration - span = default; + + continue; + } + + // We got to a point where we believe we have a header. + if (!TryTakeSingleHeader(handler, span)) + { + // Sequence needs to be CRLF and not contain an inner CR not part of terminator. + // Not parsable as a valid name:value header pair. + RejectRequestHeader(AppendEndOfLine(span, lineFeedOnly: !foundCrlf)); } } return false; } + private static byte[] AppendEndOfLine(ReadOnlySpan span, bool lineFeedOnly) + { + var array = new byte[span.Length + (lineFeedOnly ? 1 : 2)]; + + span.CopyTo(array); + array[^1] = ByteLF; + + if (!lineFeedOnly) + { + array[^2] = ByteCR; + } + + return array; + } + + // Parse a header that might cross multiple spans, and return the length of the header + // or -1 if there was a failure during parsing. private int ParseMultiSpanHeader(TRequestHandler handler, ref SequenceReader reader) { var currentSlice = reader.UnreadSequence; @@ -305,45 +324,84 @@ private int ParseMultiSpanHeader(TRequestHandler handler, ref SequenceReader headerSpan; + ReadOnlySequence header; + + var firstLineEndCharPos = lineEndPosition.Value; + currentSlice.TryGet(ref firstLineEndCharPos, out var s); + var firstEolChar = s.Span[0]; + + // Is the first EOL char the last of the current slice? if (currentSlice.Slice(reader.Position, lineEndPosition.Value).Length == currentSlice.Length - 1) { - // No enough data, so CRLF can't currently be there. - // However, we need to check the found char is CR and not LF - - // Advance 1 to include CR/LF in lineEnd - lineEnd = currentSlice.GetPosition(1, lineEndPosition.Value); - var header = currentSlice.Slice(reader.Position, lineEnd); - headerSpan = header.IsSingleSegment ? header.FirstSpan : header.ToArray(); - if (headerSpan[^1] != ByteCR) + // Get the EOL char + if (firstEolChar == ByteCR) + { + // CR without LF, can't read the header + return -1; + } + else { - RejectRequestHeader(headerSpan); + if (_disableHttp1LineFeedTerminators) + { + // LF only but disabled + + // Advance 1 to include LF in result + lineEnd = currentSlice.GetPosition(1, lineEndPosition.Value); + RejectRequestHeader(currentSlice.Slice(reader.Position, lineEnd).ToSpan()); + } } + } + + // At this point the first EOL char is not the last byte in the current slice + + // Offset 1 to include the first EOL char. + firstLineEndCharPos = currentSlice.GetPosition(1, lineEndPosition.Value); + + if (firstEolChar == ByteCR) + { + // First EOL char is CR, include the char after CR + lineEnd = currentSlice.GetPosition(2, lineEndPosition.Value); + header = currentSlice.Slice(reader.Position, lineEnd); + } + else if (_disableHttp1LineFeedTerminators) + { + // The terminator is an LF and we don't allow it. + RejectRequestHeader(currentSlice.Slice(reader.Position, firstLineEndCharPos).ToSpan()); return -1; } + else + { + // First EOL char is LF. only include this one + lineEnd = currentSlice.GetPosition(1, lineEndPosition.Value); + header = currentSlice.Slice(reader.Position, lineEnd); + } - // Advance 2 to include CR{LF?} in lineEnd - lineEnd = currentSlice.GetPosition(2, lineEndPosition.Value); - headerSpan = currentSlice.Slice(reader.Position, lineEnd).ToSpan(); + var headerSpan = header.ToSpan(); - if (headerSpan.Length < 5) + // 'a:b\n' or 'a:b\r\n' + var minHeaderSpan = _disableHttp1LineFeedTerminators ? 5 : 4; + if (headerSpan.Length < minHeaderSpan) { - // Less than min possible headerSpan is 5 bytes a:b\r\n RejectRequestHeader(headerSpan); } - if (headerSpan[^2] != ByteCR) + var terminatorSize = -1; + + if (headerSpan[^1] == ByteLF) { - // Sequence needs to be CRLF not LF first. - RejectRequestHeader(headerSpan[..^1]); + if (headerSpan[^2] == ByteCR) + { + terminatorSize = 2; + } + else if (!_disableHttp1LineFeedTerminators) + { + terminatorSize = 1; + } } - if (headerSpan[^1] != ByteLF || - // Exclude the CRLF from the headerLine and parse the header name:value pair - !TryTakeSingleHeader(handler, headerSpan[..^2])) + // Last chance to bail if the terminator size is not valid or the header doesn't parse. + if (terminatorSize == -1 || !TryTakeSingleHeader(handler, headerSpan.Slice(0, headerSpan.Length - terminatorSize))) { - // Sequence needs to be CRLF and not contain an inner CR not part of terminator. - // Not parsable as a valid name:value header pair. RejectRequestHeader(headerSpan); } @@ -438,7 +496,7 @@ private static bool TryTakeSingleHeader(TRequestHandler handler, ReadOnlySpan(trace.IsEnabled(LogLevel.Information)), + HttpParser = new HttpParser(trace.IsEnabled(LogLevel.Information), serverOptions.DisableHttp1LineFeedTerminators), SystemClock = heartbeatManager, DateHeaderValueManager = dateHeaderValueManager, ConnectionManager = connectionManager, diff --git a/src/Servers/Kestrel/Core/src/KestrelServerOptions.cs b/src/Servers/Kestrel/Core/src/KestrelServerOptions.cs index 317edfdfd14a..3dbbdee155d6 100644 --- a/src/Servers/Kestrel/Core/src/KestrelServerOptions.cs +++ b/src/Servers/Kestrel/Core/src/KestrelServerOptions.cs @@ -25,6 +25,8 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core; /// public class KestrelServerOptions { + internal const string DisableHttp1LineFeedTerminatorsSwitchKey = "Microsoft.AspNetCore.Server.Kestrel.DisableHttp1LineFeedTerminators"; + // internal to fast-path header decoding when RequestHeaderEncodingSelector is unchanged. internal static readonly Func DefaultHeaderEncodingSelector = _ => null; @@ -175,6 +177,24 @@ internal bool EnableWebTransportAndH3Datagrams set => _enableWebTransportAndH3Datagrams = value; } + /// + /// Internal AppContext switch to toggle whether a request line can end with LF only instead of CR/LF. + /// + private bool? _disableHttp1LineFeedTerminators; + internal bool DisableHttp1LineFeedTerminators + { + get + { + if (!_disableHttp1LineFeedTerminators.HasValue) + { + _disableHttp1LineFeedTerminators = AppContext.TryGetSwitch(DisableHttp1LineFeedTerminatorsSwitchKey, out var disabled) && disabled; + } + + return _disableHttp1LineFeedTerminators.Value; + } + set => _disableHttp1LineFeedTerminators = value; + } + /// /// Specifies a configuration Action to run for each newly created endpoint. Calling this again will replace /// the prior action. diff --git a/src/Servers/Kestrel/Core/src/Microsoft.AspNetCore.Server.Kestrel.Core.csproj b/src/Servers/Kestrel/Core/src/Microsoft.AspNetCore.Server.Kestrel.Core.csproj index cf77ba81c063..29990edece34 100644 --- a/src/Servers/Kestrel/Core/src/Microsoft.AspNetCore.Server.Kestrel.Core.csproj +++ b/src/Servers/Kestrel/Core/src/Microsoft.AspNetCore.Server.Kestrel.Core.csproj @@ -1,4 +1,4 @@ - + Core components of ASP.NET Core Kestrel cross-platform web server. @@ -29,6 +29,7 @@ + diff --git a/src/Servers/Kestrel/Core/src/Middleware/HttpsConnectionMiddleware.cs b/src/Servers/Kestrel/Core/src/Middleware/HttpsConnectionMiddleware.cs index 2ed4c17d476b..7b21c081fa04 100644 --- a/src/Servers/Kestrel/Core/src/Middleware/HttpsConnectionMiddleware.cs +++ b/src/Servers/Kestrel/Core/src/Middleware/HttpsConnectionMiddleware.cs @@ -12,6 +12,7 @@ using Microsoft.AspNetCore.Connections; using Microsoft.AspNetCore.Connections.Features; using Microsoft.AspNetCore.Http.Features; +using Microsoft.AspNetCore.Internal; using Microsoft.AspNetCore.Server.Kestrel.Core; using Microsoft.AspNetCore.Server.Kestrel.Core.Features; using Microsoft.AspNetCore.Server.Kestrel.Core.Internal; diff --git a/src/Servers/Kestrel/Core/test/HttpParserTests.cs b/src/Servers/Kestrel/Core/test/HttpParserTests.cs index e6898927d644..ffbd38b10a0a 100644 --- a/src/Servers/Kestrel/Core/test/HttpParserTests.cs +++ b/src/Servers/Kestrel/Core/test/HttpParserTests.cs @@ -3,9 +3,6 @@ using System; using System.Buffers; -using System.Collections.Generic; -using System.Linq; -using System.Net.Http; using System.Text; using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http; @@ -13,8 +10,6 @@ using Microsoft.AspNetCore.Testing; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging.Abstractions; -using Moq; -using Xunit; using HttpMethod = Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http.HttpMethod; namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests; @@ -39,7 +34,7 @@ public void ParsesRequestLine( #pragma warning restore xUnit1026 string expectedVersion) { - var parser = CreateParser(_nullTrace); + var parser = CreateParser(_nullTrace, false); var buffer = new ReadOnlySequence(Encoding.ASCII.GetBytes(requestLine)); var requestHandler = new RequestHandler(); @@ -58,7 +53,7 @@ public void ParsesRequestLine( [MemberData(nameof(RequestLineIncompleteData))] public void ParseRequestLineReturnsFalseWhenGivenIncompleteRequestLines(string requestLine) { - var parser = CreateParser(_nullTrace); + var parser = CreateParser(_nullTrace, false); var buffer = new ReadOnlySequence(Encoding.ASCII.GetBytes(requestLine)); var requestHandler = new RequestHandler(); @@ -69,7 +64,7 @@ public void ParseRequestLineReturnsFalseWhenGivenIncompleteRequestLines(string r [MemberData(nameof(RequestLineIncompleteData))] public void ParseRequestLineDoesNotConsumeIncompleteRequestLine(string requestLine) { - var parser = CreateParser(_nullTrace); + var parser = CreateParser(_nullTrace, false); var buffer = new ReadOnlySequence(Encoding.ASCII.GetBytes(requestLine)); var requestHandler = new RequestHandler(); @@ -87,6 +82,34 @@ public void ParseRequestLineThrowsOnInvalidRequestLine(string requestLine) var buffer = new ReadOnlySequence(Encoding.ASCII.GetBytes(requestLine)); var requestHandler = new RequestHandler(); +#pragma warning disable CS0618 // Type or member is obsolete + var exception = Assert.Throws(() => +#pragma warning restore CS0618 // Type or member is obsolete + ParseRequestLine(parser, requestHandler, buffer, out var consumed, out var examined)); + + Assert.Equal(CoreStrings.FormatBadRequest_InvalidRequestLine_Detail(requestLine[..^1].EscapeNonPrintable()), exception.Message); + Assert.Equal(StatusCodes.Status400BadRequest, exception.StatusCode); + } + + [Theory] + [MemberData(nameof(RequestLineInvalidDataLineFeedTerminator))] + public void ParseRequestSucceedsOnInvalidRequestLineLineFeedTerminator(string requestLine) + { + var parser = CreateParser(CreateEnabledTrace(), disableHttp1LineFeedTerminators: false); + var buffer = new ReadOnlySequence(Encoding.ASCII.GetBytes(requestLine)); + var requestHandler = new RequestHandler(); + + Assert.True(ParseRequestLine(parser, requestHandler, buffer, out var consumed, out var examined)); + } + + [Theory] + [MemberData(nameof(RequestLineInvalidDataLineFeedTerminator))] + public void ParseRequestLineThrowsOnInvalidRequestLineLineFeedTerminator(string requestLine) + { + var parser = CreateParser(CreateEnabledTrace()); + var buffer = new ReadOnlySequence(Encoding.ASCII.GetBytes(requestLine)); + var requestHandler = new RequestHandler(); + #pragma warning disable CS0618 // Type or member is obsolete var exception = Assert.Throws(() => #pragma warning restore CS0618 // Type or member is obsolete @@ -102,7 +125,7 @@ public void ParseRequestLineThrowsOnNonTokenCharsInCustomMethod(string method) { var requestLine = $"{method} / HTTP/1.1\r\n"; - var parser = CreateParser(CreateEnabledTrace()); + var parser = CreateParser(CreateEnabledTrace(), false); var buffer = new ReadOnlySequence(Encoding.ASCII.GetBytes(requestLine)); var requestHandler = new RequestHandler(); @@ -121,7 +144,7 @@ public void ParseRequestLineThrowsOnUnrecognizedHttpVersion(string httpVersion) { var requestLine = $"GET / {httpVersion}\r\n"; - var parser = CreateParser(CreateEnabledTrace()); + var parser = CreateParser(CreateEnabledTrace(), false); var buffer = new ReadOnlySequence(Encoding.ASCII.GetBytes(requestLine)); var requestHandler = new RequestHandler(); @@ -134,6 +157,24 @@ public void ParseRequestLineThrowsOnUnrecognizedHttpVersion(string httpVersion) Assert.Equal(StatusCodes.Status505HttpVersionNotsupported, exception.StatusCode); } + [Fact] + public void StartOfPathNotFound() + { + var requestLine = $"GET \n"; + + var parser = CreateParser(CreateEnabledTrace(), false); + var buffer = new ReadOnlySequence(Encoding.ASCII.GetBytes(requestLine)); + var requestHandler = new RequestHandler(); + +#pragma warning disable CS0618 // Type or member is obsolete + var exception = Assert.Throws(() => +#pragma warning restore CS0618 // Type or member is obsolete + ParseRequestLine(parser, requestHandler, buffer, out var consumed, out var examined)); + + Assert.Equal(CoreStrings.FormatBadRequest_InvalidRequestLine_Detail("GET "), exception.Message); + Assert.Equal(StatusCodes.Status400BadRequest, exception.StatusCode); + } + [Theory] [InlineData("\r")] [InlineData("H")] @@ -173,7 +214,7 @@ public void ParseRequestLineThrowsOnUnrecognizedHttpVersion(string httpVersion) [InlineData("Header-1: value1\r\nHeader-2: value2\r\n\r")] public void ParseHeadersReturnsFalseWhenGivenIncompleteHeaders(string rawHeaders) { - var parser = CreateParser(_nullTrace); + var parser = CreateParser(_nullTrace, false); var buffer = new ReadOnlySequence(Encoding.ASCII.GetBytes(rawHeaders)); var requestHandler = new RequestHandler(); @@ -199,7 +240,7 @@ public void ParseHeadersReturnsFalseWhenGivenIncompleteHeaders(string rawHeaders [InlineData("Header: value\r")] public void ParseHeadersDoesNotConsumeIncompleteHeader(string rawHeaders) { - var parser = CreateParser(_nullTrace); + var parser = CreateParser(_nullTrace, false); var buffer = new ReadOnlySequence(Encoding.ASCII.GetBytes(rawHeaders)); var requestHandler = new RequestHandler(); @@ -224,6 +265,8 @@ public void ParseHeadersCanReadHeaderValueWithoutLeadingWhitespace() [InlineData("Cookie:\r\nConnection: close\r\n\r\n", "Cookie", "", "Connection", "close")] [InlineData("Connection: close\r\nCookie: \r\n\r\n", "Connection", "close", "Cookie", "")] [InlineData("Connection: close\r\nCookie:\r\n\r\n", "Connection", "close", "Cookie", "")] + [InlineData("a:b\r\n\r\n", "a", "b", null, null)] + [InlineData("a: b\r\n\r\n", "a", "b", null, null)] public void ParseHeadersCanParseEmptyHeaderValues( string rawHeaders, string expectedHeaderName1, @@ -238,7 +281,116 @@ public void ParseHeadersCanParseEmptyHeaderValues( ? new[] { expectedHeaderValue1 } : new[] { expectedHeaderValue1, expectedHeaderValue2 }; - VerifyRawHeaders(rawHeaders, expectedHeaderNames, expectedHeaderValues); + VerifyRawHeaders(rawHeaders, expectedHeaderNames, expectedHeaderValues, disableHttp1LineFeedTerminators: false); + } + + [Theory] + [InlineData("Cookie: \n\r\n", "Cookie", "", null, null)] + [InlineData("Cookie:\n\r\n", "Cookie", "", null, null)] + [InlineData("Cookie: \nConnection: close\r\n\r\n", "Cookie", "", "Connection", "close")] + [InlineData("Cookie: \r\nConnection: close\n\r\n", "Cookie", "", "Connection", "close")] + [InlineData("Cookie:\nConnection: close\r\n\r\n", "Cookie", "", "Connection", "close")] + [InlineData("Cookie:\r\nConnection: close\n\r\n", "Cookie", "", "Connection", "close")] + [InlineData("Connection: close\nCookie: \r\n\r\n", "Connection", "close", "Cookie", "")] + [InlineData("Connection: close\r\nCookie: \n\r\n", "Connection", "close", "Cookie", "")] + [InlineData("Connection: close\nCookie:\r\n\r\n", "Connection", "close", "Cookie", "")] + [InlineData("Connection: close\r\nCookie:\n\r\n", "Connection", "close", "Cookie", "")] + [InlineData("a:b\n\r\n", "a", "b", null, null)] + [InlineData("a: b\n\r\n", "a", "b", null, null)] + [InlineData("a:b\n\n", "a", "b", null, null)] + [InlineData("a: b\n\n", "a", "b", null, null)] + public void ParseHeadersCantParseSingleLineFeedWihtoutLineFeedTerminatorEnabled( + string rawHeaders, + string expectedHeaderName1, + string expectedHeaderValue1, + string expectedHeaderName2, + string expectedHeaderValue2) + { + var expectedHeaderNames = expectedHeaderName2 == null + ? new[] { expectedHeaderName1 } + : new[] { expectedHeaderName1, expectedHeaderName2 }; + var expectedHeaderValues = expectedHeaderValue2 == null + ? new[] { expectedHeaderValue1 } + : new[] { expectedHeaderValue1, expectedHeaderValue2 }; + +#pragma warning disable CS0618 // Type or member is obsolete + Assert.Throws(() => VerifyRawHeaders(rawHeaders, expectedHeaderNames, expectedHeaderValues, disableHttp1LineFeedTerminators: true)); +#pragma warning restore CS0618 // Type or member is obsolete + } + + [Theory] + [InlineData("Cookie: \n\r\n", "Cookie", "", null, null)] + [InlineData("Cookie:\n\r\n", "Cookie", "", null, null)] + [InlineData("Cookie: \nConnection: close\r\n\r\n", "Cookie", "", "Connection", "close")] + [InlineData("Cookie: \r\nConnection: close\n\r\n", "Cookie", "", "Connection", "close")] + [InlineData("Cookie:\nConnection: close\r\n\r\n", "Cookie", "", "Connection", "close")] + [InlineData("Cookie:\r\nConnection: close\n\r\n", "Cookie", "", "Connection", "close")] + [InlineData("Connection: close\nCookie: \r\n\r\n", "Connection", "close", "Cookie", "")] + [InlineData("Connection: close\r\nCookie: \n\r\n", "Connection", "close", "Cookie", "")] + [InlineData("Connection: close\nCookie:\r\n\r\n", "Connection", "close", "Cookie", "")] + [InlineData("Connection: close\r\nCookie:\n\r\n", "Connection", "close", "Cookie", "")] + public void ParseHeadersCanParseSingleLineFeedWithLineFeedTerminatorEnabled( + string rawHeaders, + string expectedHeaderName1, + string expectedHeaderValue1, + string expectedHeaderName2, + string expectedHeaderValue2) + { + var expectedHeaderNames = expectedHeaderName2 == null + ? new[] { expectedHeaderName1 } + : new[] { expectedHeaderName1, expectedHeaderName2 }; + var expectedHeaderValues = expectedHeaderValue2 == null + ? new[] { expectedHeaderValue1 } + : new[] { expectedHeaderValue1, expectedHeaderValue2 }; + + VerifyRawHeaders(rawHeaders, expectedHeaderNames, expectedHeaderValues, disableHttp1LineFeedTerminators: false); + } + + [Theory] + [InlineData("a: b\r\n\n", "a", "b", null, null)] + [InlineData("a: b\n\n", "a", "b", null, null)] + [InlineData("a: b\nc: d\r\n\n", "a", "b", "c", "d")] + [InlineData("a: b\nc: d\n\n", "a", "b", "c", "d")] + public void ParseHeadersCantEndWithLineFeedTerminator( + string rawHeaders, + string expectedHeaderName1, + string expectedHeaderValue1, + string expectedHeaderName2, + string expectedHeaderValue2) + { + var expectedHeaderNames = expectedHeaderName2 == null + ? new[] { expectedHeaderName1 } + : new[] { expectedHeaderName1, expectedHeaderName2 }; + var expectedHeaderValues = expectedHeaderValue2 == null + ? new[] { expectedHeaderValue1 } + : new[] { expectedHeaderValue1, expectedHeaderValue2 }; + +#pragma warning disable CS0618 // Type or member is obsolete + Assert.Throws(() => VerifyRawHeaders(rawHeaders, expectedHeaderNames, expectedHeaderValues, disableHttp1LineFeedTerminators: true)); +#pragma warning restore CS0618 // Type or member is obsolete + } + + [Theory] + [InlineData("a:b\n\r\n", "a", "b", null, null)] + [InlineData("a: b\n\r\n", "a", "b", null, null)] + [InlineData("a: b\nc: d\n\r\n", "a", "b", "c", "d")] + [InlineData("a: b\nc: d\n\n", "a", "b", "c", "d")] + [InlineData("a: b\n\n", "a", "b", null, null)] + public void ParseHeadersCanEndAfterLineFeedTerminator( + string rawHeaders, + string expectedHeaderName1, + string expectedHeaderValue1, + string expectedHeaderName2, + string expectedHeaderValue2) + { + var expectedHeaderNames = expectedHeaderName2 == null + ? new[] { expectedHeaderName1 } + : new[] { expectedHeaderName1, expectedHeaderName2 }; + var expectedHeaderValues = expectedHeaderValue2 == null + ? new[] { expectedHeaderValue1 } + : new[] { expectedHeaderValue1, expectedHeaderValue2 }; + + VerifyRawHeaders(rawHeaders, expectedHeaderNames, expectedHeaderValues, disableHttp1LineFeedTerminators: false); } [Theory] @@ -289,7 +441,7 @@ public void ParseHeadersPreservesWhitespaceWithinHeaderValue(string headerValue) [Fact] public void ParseHeadersConsumesBytesCorrectlyAtEnd() { - var parser = CreateParser(_nullTrace); + var parser = CreateParser(_nullTrace, false); const string headerLine = "Header: value\r\n\r"; var buffer1 = new ReadOnlySequence(Encoding.ASCII.GetBytes(headerLine)); @@ -312,7 +464,27 @@ public void ParseHeadersConsumesBytesCorrectlyAtEnd() [MemberData(nameof(RequestHeaderInvalidData))] public void ParseHeadersThrowsOnInvalidRequestHeaders(string rawHeaders, string expectedExceptionMessage) { - var parser = CreateParser(CreateEnabledTrace()); + var parser = CreateParser(CreateEnabledTrace(), false); + var buffer = new ReadOnlySequence(Encoding.ASCII.GetBytes(rawHeaders)); + var requestHandler = new RequestHandler(); + +#pragma warning disable CS0618 // Type or member is obsolete + var exception = Assert.Throws(() => +#pragma warning restore CS0618 // Type or member is obsolete + { + var reader = new SequenceReader(buffer); + parser.ParseHeaders(requestHandler, ref reader); + }); + + Assert.Equal(expectedExceptionMessage, exception.Message); + Assert.Equal(StatusCodes.Status400BadRequest, exception.StatusCode); + } + + [Theory] + [MemberData(nameof(RequestHeaderInvalidDataLineFeedTerminator))] + public void ParseHeadersThrowsOnInvalidRequestHeadersLineFeedTerminator(string rawHeaders, string expectedExceptionMessage) + { + var parser = CreateParser(CreateEnabledTrace(), true); var buffer = new ReadOnlySequence(Encoding.ASCII.GetBytes(rawHeaders)); var requestHandler = new RequestHandler(); @@ -374,7 +546,7 @@ public void ExceptionDetailNotIncludedWhenLogLevelInformationNotEnabled() [Fact] public void ParseRequestLineSplitBufferWithoutNewLineDoesNotUpdateConsumed() { - var parser = CreateParser(_nullTrace); + var parser = CreateParser(_nullTrace, false); var buffer = ReadOnlySequenceFactory.CreateSegments( Encoding.ASCII.GetBytes("GET "), Encoding.ASCII.GetBytes("/")); @@ -390,7 +562,7 @@ public void ParseRequestLineSplitBufferWithoutNewLineDoesNotUpdateConsumed() [Fact] public void ParseRequestLineTlsOverHttp() { - var parser = CreateParser(_nullTrace); + var parser = CreateParser(_nullTrace, false); var buffer = ReadOnlySequenceFactory.CreateSegments(new byte[] { 0x16, 0x03, 0x01, 0x02, 0x00, 0x01, 0x00, 0xfc, 0x03, 0x03, 0x03, 0xca, 0xe0, 0xfd, 0x0a }); var requestHandler = new RequestHandler(); @@ -410,7 +582,7 @@ public void ParseRequestLineTlsOverHttp() [MemberData(nameof(RequestHeaderInvalidData))] public void ParseHeadersThrowsOnInvalidRequestHeadersWithGratuitouslySplitBuffers(string rawHeaders, string expectedExceptionMessage) { - var parser = CreateParser(CreateEnabledTrace()); + var parser = CreateParser(CreateEnabledTrace(), false); var buffer = BytePerSegmentTestSequenceFactory.Instance.CreateWithContent(rawHeaders); var requestHandler = new RequestHandler(); @@ -426,11 +598,33 @@ public void ParseHeadersThrowsOnInvalidRequestHeadersWithGratuitouslySplitBuffer Assert.Equal(StatusCodes.Status400BadRequest, exception.StatusCode); } - [Fact] - public void ParseHeadersWithGratuitouslySplitBuffers() + [Theory] + [MemberData(nameof(RequestHeaderInvalidDataLineFeedTerminator))] + public void ParseHeadersThrowsOnInvalidRequestHeadersWithGratuitouslySplitBuffersLineFeedTerminator(string rawHeaders, string expectedExceptionMessage) { - var parser = CreateParser(_nullTrace); - var buffer = BytePerSegmentTestSequenceFactory.Instance.CreateWithContent("Host:\r\nConnection: keep-alive\r\n\r\n"); + var parser = CreateParser(CreateEnabledTrace(), true); + var buffer = BytePerSegmentTestSequenceFactory.Instance.CreateWithContent(rawHeaders); + var requestHandler = new RequestHandler(); + +#pragma warning disable CS0618 // Type or member is obsolete + var exception = Assert.Throws(() => +#pragma warning restore CS0618 // Type or member is obsolete + { + var reader = new SequenceReader(buffer); + parser.ParseHeaders(requestHandler, ref reader); + }); + + Assert.Equal(expectedExceptionMessage, exception.Message); + Assert.Equal(StatusCodes.Status400BadRequest, exception.StatusCode); + } + + [Theory] + [InlineData("Host:\r\nConnection: keep-alive\r\n\r\n")] + [InlineData("A:B\r\nB: C\r\n\r\n")] + public void ParseHeadersWithGratuitouslySplitBuffers(string headers) + { + var parser = CreateParser(_nullTrace, false); + var buffer = BytePerSegmentTestSequenceFactory.Instance.CreateWithContent(headers); var requestHandler = new RequestHandler(); var reader = new SequenceReader(buffer); @@ -439,11 +633,44 @@ public void ParseHeadersWithGratuitouslySplitBuffers() Assert.True(result); } - [Fact] - public void ParseHeadersWithGratuitouslySplitBuffers2() + [Theory] + [InlineData("Host: \r\nConnection: keep-alive\r")] + public void ParseHeaderLineIncompleteDataWithGratuitouslySplitBuffers(string headers) { - var parser = CreateParser(_nullTrace); - var buffer = BytePerSegmentTestSequenceFactory.Instance.CreateWithContent("A:B\r\nB: C\r\n\r\n"); + var parser = CreateParser(_nullTrace, false); + var buffer = BytePerSegmentTestSequenceFactory.Instance.CreateWithContent(headers); + + var requestHandler = new RequestHandler(); + var reader = new SequenceReader(buffer); + var result = parser.ParseHeaders(requestHandler, ref reader); + + Assert.False(result); + } + + [Theory] + [InlineData("Host: \r\nConnection: keep-alive\r")] + public void ParseHeaderLineIncompleteData(string headers) + { + var parser = CreateParser(_nullTrace, false); + var buffer = new ReadOnlySequence(Encoding.ASCII.GetBytes(headers)); + + var requestHandler = new RequestHandler(); + var reader = new SequenceReader(buffer); + var result = parser.ParseHeaders(requestHandler, ref reader); + + Assert.False(result); + } + + [Theory] + [InlineData("Host:\nConnection: keep-alive\r\n\r\n")] + [InlineData("Host:\r\nConnection: keep-alive\n\r\n")] + [InlineData("A:B\nB: C\r\n\r\n")] + [InlineData("A:B\r\nB: C\n\r\n")] + [InlineData("Host:\r\nConnection: keep-alive\n\n")] + public void ParseHeadersWithGratuitouslySplitBuffersQuirkMode(string headers) + { + var parser = CreateParser(_nullTrace, disableHttp1LineFeedTerminators: false); + var buffer = BytePerSegmentTestSequenceFactory.Instance.CreateWithContent(headers); var requestHandler = new RequestHandler(); var reader = new SequenceReader(buffer); @@ -474,7 +701,7 @@ private void VerifyHeader( string rawHeaderValue, string expectedHeaderValue) { - var parser = CreateParser(_nullTrace); + var parser = CreateParser(_nullTrace, false); var buffer = new ReadOnlySequence(Encoding.ASCII.GetBytes($"{headerName}:{rawHeaderValue}\r\n")); var requestHandler = new RequestHandler(); @@ -488,11 +715,11 @@ private void VerifyHeader( Assert.True(buffer.Slice(reader.Position).IsEmpty); } - private void VerifyRawHeaders(string rawHeaders, IEnumerable expectedHeaderNames, IEnumerable expectedHeaderValues) + private void VerifyRawHeaders(string rawHeaders, IEnumerable expectedHeaderNames, IEnumerable expectedHeaderValues, bool disableHttp1LineFeedTerminators = true) { Assert.True(expectedHeaderNames.Count() == expectedHeaderValues.Count(), $"{nameof(expectedHeaderNames)} and {nameof(expectedHeaderValues)} sizes must match"); - var parser = CreateParser(_nullTrace); + var parser = CreateParser(_nullTrace, disableHttp1LineFeedTerminators); var buffer = new ReadOnlySequence(Encoding.ASCII.GetBytes(rawHeaders)); var requestHandler = new RequestHandler(); @@ -507,12 +734,14 @@ private void VerifyRawHeaders(string rawHeaders, IEnumerable expectedHea Assert.True(buffer.Slice(reader.Position).IsEmpty); } - private IHttpParser CreateParser(KestrelTrace log) => new HttpParser(log.IsEnabled(LogLevel.Information)); + private IHttpParser CreateParser(KestrelTrace log, bool disableHttp1LineFeedTerminators = true) => new HttpParser(log.IsEnabled(LogLevel.Information), disableHttp1LineFeedTerminators); public static IEnumerable RequestLineValidData => HttpParsingData.RequestLineValidData; public static IEnumerable RequestLineIncompleteData => HttpParsingData.RequestLineIncompleteData.Select(requestLine => new[] { requestLine }); + public static IEnumerable RequestLineInvalidDataLineFeedTerminator => HttpParsingData.RequestLineInvalidDataLineFeedTerminator.Select(requestLine => new[] { requestLine }); + public static IEnumerable RequestLineInvalidData => HttpParsingData.RequestLineInvalidData.Select(requestLine => new[] { requestLine }); public static IEnumerable MethodWithNonTokenCharData => HttpParsingData.MethodWithNonTokenCharData.Select(method => new[] { method }); @@ -521,6 +750,8 @@ private void VerifyRawHeaders(string rawHeaders, IEnumerable expectedHea public static IEnumerable RequestHeaderInvalidData => HttpParsingData.RequestHeaderInvalidData; + public static IEnumerable RequestHeaderInvalidDataLineFeedTerminator => HttpParsingData.RequestHeaderInvalidDataLineFeedTerminator; + private class RequestHandler : IHttpRequestLineHandler, IHttpHeadersHandler { public string Method { get; set; } diff --git a/src/Servers/Kestrel/shared/test/HttpParsingData.cs b/src/Servers/Kestrel/shared/test/HttpParsingData.cs index 8baf18513ced..f259dd9232ee 100644 --- a/src/Servers/Kestrel/shared/test/HttpParsingData.cs +++ b/src/Servers/Kestrel/shared/test/HttpParsingData.cs @@ -205,7 +205,15 @@ public static IEnumerable RequestLineInvalidData "CUSTOM / HTTP/1.1a\n", "CUSTOM / HTTP/1.1a\r\n", "CUSTOM / HTTP/1.1ab\r\n", + "CUSTOM / H\n", + "CUSTOM / HT\n", + "CUSTOM / HTT\n", + "CUSTOM / HTTP\n", + "CUSTOM / HTTP/\n", + "CUSTOM / HTTP/1\n", + "CUSTOM / HTTP/1.\n", "CUSTOM / hello\r\n", + "CUSTOM / hello\n", "CUSTOM ? HTTP/1.1\r\n", "CUSTOM /a?b=cHTTP/1.1\r\n", "CUSTOM /a%20bHTTP/1.1\r\n", @@ -217,6 +225,21 @@ public static IEnumerable RequestLineInvalidData } } + // This list is valid in quirk mode + public static IEnumerable RequestLineInvalidDataLineFeedTerminator + { + get + { + return new[] + { + "GET / HTTP/1.0\n", + "GET / HTTP/1.1\n", + "CUSTOM / HTTP/1.0\n", + "CUSTOM / HTTP/1.1\n", + }; + } + } + // Bad HTTP Methods (invalid according to RFC) public static IEnumerable MethodWithNonTokenCharData { @@ -364,13 +387,19 @@ public static IEnumerable TargetWithNullCharData "8charact", }; - public static IEnumerable RequestHeaderInvalidData => new[] + public static IEnumerable RequestHeaderInvalidDataLineFeedTerminator => new[] { // Missing CR new[] { "Header: value\n\r\n", CoreStrings.FormatBadRequest_InvalidRequestHeader_Detail(@"Header: value\x0A") }, new[] { "Header-1: value1\nHeader-2: value2\r\n\r\n", CoreStrings.FormatBadRequest_InvalidRequestHeader_Detail(@"Header-1: value1\x0A") }, new[] { "Header-1: value1\r\nHeader-2: value2\n\r\n", CoreStrings.FormatBadRequest_InvalidRequestHeader_Detail(@"Header-2: value2\x0A") }, + // Empty header name + new[] { ":a\n", CoreStrings.FormatBadRequest_InvalidRequestHeader_Detail(@":a\x0A") }, + }; + + public static IEnumerable RequestHeaderInvalidData => new[] + { // Line folding new[] { "Header: line1\r\n line2\r\n\r\n", CoreStrings.FormatBadRequest_InvalidRequestHeader_Detail(@" line2\x0D\x0A") }, new[] { "Header: line1\r\n\tline2\r\n\r\n", CoreStrings.FormatBadRequest_InvalidRequestHeader_Detail(@"\x09line2\x0D\x0A") }, @@ -404,7 +433,7 @@ public static IEnumerable TargetWithNullCharData new[] { "Header-1 value1\r\n\r\n", CoreStrings.FormatBadRequest_InvalidRequestHeader_Detail(@"Header-1 value1\x0D\x0A") }, new[] { "Header-1 value1\r\nHeader-2: value2\r\n\r\n", CoreStrings.FormatBadRequest_InvalidRequestHeader_Detail(@"Header-1 value1\x0D\x0A") }, new[] { "Header-1: value1\r\nHeader-2 value2\r\n\r\n", CoreStrings.FormatBadRequest_InvalidRequestHeader_Detail(@"Header-2 value2\x0D\x0A") }, - new[] { "\n", CoreStrings.FormatBadRequest_InvalidRequestHeader_Detail(@"\x0A") }, + new[] { "HeaderValue1\r\n\r\n", CoreStrings.FormatBadRequest_InvalidRequestHeader_Detail(@"HeaderValue1\x0D\x0A") }, // Starting with whitespace new[] { " Header: value\r\n\r\n", CoreStrings.FormatBadRequest_InvalidRequestHeader_Detail(@" Header: value\x0D\x0A") }, @@ -435,11 +464,13 @@ public static IEnumerable TargetWithNullCharData // Headers not ending in CRLF line new[] { "Header-1: value1\r\nHeader-2: value2\r\n\r\r", CoreStrings.BadRequest_InvalidRequestHeadersNoCRLF }, - new[] { "Header-1: value1\r\nHeader-2: value2\r\n\r ", CoreStrings.BadRequest_InvalidRequestHeadersNoCRLF }, + new[] { "Header-1: value1\r\nHeader-2: value2\r\n\r ", CoreStrings.BadRequest_InvalidRequestHeadersNoCRLF }, new[] { "Header-1: value1\r\nHeader-2: value2\r\n\r \n", CoreStrings.BadRequest_InvalidRequestHeadersNoCRLF }, + new[] { "Header-1: value1\r\nHeader-2\t: value2 \n", CoreStrings.FormatBadRequest_InvalidRequestHeader_Detail(@"Header-2\x09: value2 \x0A") }, // Empty header name new[] { ": value\r\n\r\n", CoreStrings.FormatBadRequest_InvalidRequestHeader_Detail(@": value\x0D\x0A") }, + new[] { ":a\r\n", CoreStrings.FormatBadRequest_InvalidRequestHeader_Detail(@":a\x0D\x0A") }, }; public static TheoryData HostHeaderData diff --git a/src/Servers/Kestrel/shared/test/TestServiceContext.cs b/src/Servers/Kestrel/shared/test/TestServiceContext.cs index e4fb4e4fd585..119a13020b91 100644 --- a/src/Servers/Kestrel/shared/test/TestServiceContext.cs +++ b/src/Servers/Kestrel/shared/test/TestServiceContext.cs @@ -17,17 +17,17 @@ internal class TestServiceContext : ServiceContext { public TestServiceContext() { - Initialize(NullLoggerFactory.Instance, CreateLoggingTrace(NullLoggerFactory.Instance)); + Initialize(NullLoggerFactory.Instance, CreateLoggingTrace(NullLoggerFactory.Instance), false); } - public TestServiceContext(ILoggerFactory loggerFactory) + public TestServiceContext(ILoggerFactory loggerFactory, bool disableHttp1LineFeedTerminators = true) { - Initialize(loggerFactory, CreateLoggingTrace(loggerFactory)); + Initialize(loggerFactory, CreateLoggingTrace(loggerFactory), disableHttp1LineFeedTerminators); } - public TestServiceContext(ILoggerFactory loggerFactory, KestrelTrace kestrelTrace) + public TestServiceContext(ILoggerFactory loggerFactory, KestrelTrace kestrelTrace, bool disableHttp1LineFeedTerminators = true) { - Initialize(loggerFactory, kestrelTrace); + Initialize(loggerFactory, kestrelTrace, disableHttp1LineFeedTerminators); } private static KestrelTrace CreateLoggingTrace(ILoggerFactory loggerFactory) @@ -49,7 +49,7 @@ public void InitializeHeartbeat() SystemClock = heartbeatManager; } - private void Initialize(ILoggerFactory loggerFactory, KestrelTrace kestrelTrace) + private void Initialize(ILoggerFactory loggerFactory, KestrelTrace kestrelTrace, bool disableHttp1LineFeedTerminators) { LoggerFactory = loggerFactory; Log = kestrelTrace; @@ -58,7 +58,7 @@ private void Initialize(ILoggerFactory loggerFactory, KestrelTrace kestrelTrace) SystemClock = MockSystemClock; DateHeaderValueManager = new DateHeaderValueManager(); ConnectionManager = new ConnectionManager(Log, ResourceCounter.Unlimited); - HttpParser = new HttpParser(Log.IsEnabled(LogLevel.Information)); + HttpParser = new HttpParser(Log.IsEnabled(LogLevel.Information), disableHttp1LineFeedTerminators); ServerOptions = new KestrelServerOptions { AddServerHeader = false diff --git a/src/Servers/Kestrel/test/InMemory.FunctionalTests/BadHttpRequestTests.cs b/src/Servers/Kestrel/test/InMemory.FunctionalTests/BadHttpRequestTests.cs index 564ba843842a..5bde02517fff 100644 --- a/src/Servers/Kestrel/test/InMemory.FunctionalTests/BadHttpRequestTests.cs +++ b/src/Servers/Kestrel/test/InMemory.FunctionalTests/BadHttpRequestTests.cs @@ -524,5 +524,7 @@ public static TheoryData InvalidRequestLineData public static IEnumerable InvalidRequestHeaderData => HttpParsingData.RequestHeaderInvalidData; + public static IEnumerable InvalidRequestHeaderDataLineFeedTerminator => HttpParsingData.RequestHeaderInvalidDataLineFeedTerminator; + public static TheoryData InvalidHostHeaderData => HttpParsingData.HostHeaderInvalidData; } diff --git a/src/Servers/Kestrel/test/InMemory.FunctionalTests/RequestTests.cs b/src/Servers/Kestrel/test/InMemory.FunctionalTests/RequestTests.cs index 1cb13d90fe68..7c2a98a2f626 100644 --- a/src/Servers/Kestrel/test/InMemory.FunctionalTests/RequestTests.cs +++ b/src/Servers/Kestrel/test/InMemory.FunctionalTests/RequestTests.cs @@ -2284,6 +2284,49 @@ await connection.Receive( } } + [Fact] + public async Task SingleLineFeedIsSupportedAnywhere() + { + // Exercises all combinations of LF and CRLF as line separators. + // Uses a bit mask for all the possible combinations. + + var lines = new[] + { + $"GET / HTTP/1.1", + "Content-Length: 0", + $"Host: localhost", + "", + }; + + await using (var server = new TestServer(context => Task.CompletedTask, new TestServiceContext(LoggerFactory, disableHttp1LineFeedTerminators: false))) + { + var mask = Math.Pow(2, lines.Length) - 1; + + for (var m = 0; m <= mask; m++) + { + using (var client = server.CreateConnection()) + { + var sb = new StringBuilder(); + + for (var pos = 0; pos < lines.Length; pos++) + { + sb.Append(lines[pos]); + var separator = (m & (1 << pos)) != 0 ? "\n" : "\r\n"; + sb.Append(separator); + } + + var text = sb.ToString(); + var writer = new StreamWriter(client.Stream, Encoding.GetEncoding("iso-8859-1")); + await writer.WriteAsync(text).ConfigureAwait(false); + await writer.FlushAsync().ConfigureAwait(false); + await client.Stream.FlushAsync().ConfigureAwait(false); + + await client.Receive("HTTP/1.1 200"); + } + } + } + } + public static TheoryData HostHeaderData => HttpParsingData.HostHeaderData; private class IntAsClass diff --git a/src/Servers/Kestrel/Core/src/Internal/CancellationTokenSourcePool.cs b/src/Shared/CancellationTokenSourcePool.cs similarity index 96% rename from src/Servers/Kestrel/Core/src/Internal/CancellationTokenSourcePool.cs rename to src/Shared/CancellationTokenSourcePool.cs index 72e6465b7851..94279b1f970f 100644 --- a/src/Servers/Kestrel/Core/src/Internal/CancellationTokenSourcePool.cs +++ b/src/Shared/CancellationTokenSourcePool.cs @@ -3,7 +3,7 @@ using System.Collections.Concurrent; -namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal; +namespace Microsoft.AspNetCore.Internal; internal sealed class CancellationTokenSourcePool { diff --git a/src/SignalR/clients/java/signalr/core/src/main/java/com/microsoft/signalr/CallbackMap.java b/src/SignalR/clients/java/signalr/core/src/main/java/com/microsoft/signalr/CallbackMap.java index 627aa41da114..1f3bfcd59eb1 100644 --- a/src/SignalR/clients/java/signalr/core/src/main/java/com/microsoft/signalr/CallbackMap.java +++ b/src/SignalR/clients/java/signalr/core/src/main/java/com/microsoft/signalr/CallbackMap.java @@ -31,7 +31,12 @@ public InvocationHandler put(String target, Object action, Type... types) { } } } + methodHandlers = new ArrayList<>(methodHandlers); methodHandlers.add(handler); + + // replace List in handlers map + handlers.remove(target); + handlers.put(target, methodHandlers); return handler; } finally { lock.unlock(); @@ -41,7 +46,7 @@ public InvocationHandler put(String target, Object action, Type... types) { public List get(String key) { try { lock.lock(); - return handlers.get(key); + return this.handlers.get(key); } finally { lock.unlock(); } @@ -55,4 +60,21 @@ public void remove(String key) { lock.unlock(); } } + + public void remove(String key, InvocationHandler handler) { + try { + lock.lock(); + List handlers = this.handlers.get(key); + if (handlers != null) { + handlers = new ArrayList<>(handlers); + handlers.remove(handler); + + // replace List in handlers map + this.handlers.remove(key); + this.handlers.put(key, handlers); + } + } finally { + lock.unlock(); + } + } } diff --git a/src/SignalR/clients/java/signalr/core/src/main/java/com/microsoft/signalr/Subscription.java b/src/SignalR/clients/java/signalr/core/src/main/java/com/microsoft/signalr/Subscription.java index ab72531472b2..e67e04bd599b 100644 --- a/src/SignalR/clients/java/signalr/core/src/main/java/com/microsoft/signalr/Subscription.java +++ b/src/SignalR/clients/java/signalr/core/src/main/java/com/microsoft/signalr/Subscription.java @@ -23,9 +23,6 @@ public class Subscription { * Removes the client method handler represented by this subscription. */ public void unsubscribe() { - List handler = this.handlers.get(target); - if (handler != null) { - handler.remove(this.handler); - } + this.handlers.remove(this.target, this.handler); } } diff --git a/src/SignalR/clients/ts/FunctionalTests/Startup.cs b/src/SignalR/clients/ts/FunctionalTests/Startup.cs index 1d831c9d28c9..67eb2878ade3 100644 --- a/src/SignalR/clients/ts/FunctionalTests/Startup.cs +++ b/src/SignalR/clients/ts/FunctionalTests/Startup.cs @@ -232,7 +232,7 @@ public void Configure(IApplicationBuilder app, IWebHostEnvironment env, ILogger< { try { - var result = await hubContext.Clients.Client(id).InvokeAsync("Result"); + var result = await hubContext.Clients.Client(id).InvokeAsync("Result", cancellationToken: default); return result.ToString(CultureInfo.InvariantCulture); } catch (Exception ex) diff --git a/src/SignalR/common/Protocols.Json/src/Microsoft.AspNetCore.SignalR.Protocols.Json.csproj b/src/SignalR/common/Protocols.Json/src/Microsoft.AspNetCore.SignalR.Protocols.Json.csproj index 34a5f17dd19d..38694e211d3f 100644 --- a/src/SignalR/common/Protocols.Json/src/Microsoft.AspNetCore.SignalR.Protocols.Json.csproj +++ b/src/SignalR/common/Protocols.Json/src/Microsoft.AspNetCore.SignalR.Protocols.Json.csproj @@ -17,6 +17,7 @@ + diff --git a/src/SignalR/common/Protocols.Json/src/Protocol/JsonHubProtocol.cs b/src/SignalR/common/Protocols.Json/src/Protocol/JsonHubProtocol.cs index c3930f66dda2..6d726e0c2c17 100644 --- a/src/SignalR/common/Protocols.Json/src/Protocol/JsonHubProtocol.cs +++ b/src/SignalR/common/Protocols.Json/src/Protocol/JsonHubProtocol.cs @@ -230,8 +230,16 @@ public ReadOnlyMemory GetMessageBytes(HubMessage message) else { // If we have an invocation id already we can parse the end result - var returnType = binder.GetReturnType(invocationId); - result = BindType(ref reader, input, returnType); + var returnType = ProtocolHelper.TryGetReturnType(binder, invocationId); + if (returnType is null) + { + reader.Skip(); + result = null; + } + else + { + result = BindType(ref reader, input, returnType); + } } } else if (reader.ValueTextEquals(ItemPropertyNameBytes.EncodedUtf8Bytes)) @@ -408,8 +416,15 @@ public ReadOnlyMemory GetMessageBytes(HubMessage message) if (hasResultToken) { - var returnType = binder.GetReturnType(invocationId); - result = BindType(ref resultToken, input, returnType); + var returnType = ProtocolHelper.TryGetReturnType(binder, invocationId); + if (returnType is null) + { + result = null; + } + else + { + result = BindType(ref resultToken, input, returnType); + } } message = BindCompletionMessage(invocationId, error, result, hasResult); diff --git a/src/SignalR/common/Protocols.MessagePack/src/Microsoft.AspNetCore.SignalR.Protocols.MessagePack.csproj b/src/SignalR/common/Protocols.MessagePack/src/Microsoft.AspNetCore.SignalR.Protocols.MessagePack.csproj index 69554b03f11a..7d3882c020b3 100644 --- a/src/SignalR/common/Protocols.MessagePack/src/Microsoft.AspNetCore.SignalR.Protocols.MessagePack.csproj +++ b/src/SignalR/common/Protocols.MessagePack/src/Microsoft.AspNetCore.SignalR.Protocols.MessagePack.csproj @@ -12,6 +12,7 @@ + diff --git a/src/SignalR/common/Protocols.MessagePack/src/Protocol/MessagePackHubProtocolWorker.cs b/src/SignalR/common/Protocols.MessagePack/src/Protocol/MessagePackHubProtocolWorker.cs index 8ce10b662bd9..df0dc0c7a8a9 100644 --- a/src/SignalR/common/Protocols.MessagePack/src/Protocol/MessagePackHubProtocolWorker.cs +++ b/src/SignalR/common/Protocols.MessagePack/src/Protocol/MessagePackHubProtocolWorker.cs @@ -162,14 +162,21 @@ private CompletionMessage CreateCompletionMessage(ref MessagePackReader reader, error = ReadString(ref reader, "error"); break; case NonVoidResult: - var itemType = binder.GetReturnType(invocationId); - if (itemType == typeof(RawResult)) + var itemType = ProtocolHelper.TryGetReturnType(binder, invocationId); + if (itemType is null) { - result = new RawResult(reader.ReadRaw()); + reader.Skip(); } else { - result = DeserializeObject(ref reader, itemType, "argument"); + if (itemType == typeof(RawResult)) + { + result = new RawResult(reader.ReadRaw()); + } + else + { + result = DeserializeObject(ref reader, itemType, "argument"); + } } hasResult = true; break; diff --git a/src/SignalR/common/Protocols.NewtonsoftJson/src/Microsoft.AspNetCore.SignalR.Protocols.NewtonsoftJson.csproj b/src/SignalR/common/Protocols.NewtonsoftJson/src/Microsoft.AspNetCore.SignalR.Protocols.NewtonsoftJson.csproj index 2167aa456d47..f905437355d7 100644 --- a/src/SignalR/common/Protocols.NewtonsoftJson/src/Microsoft.AspNetCore.SignalR.Protocols.NewtonsoftJson.csproj +++ b/src/SignalR/common/Protocols.NewtonsoftJson/src/Microsoft.AspNetCore.SignalR.Protocols.NewtonsoftJson.csproj @@ -15,6 +15,7 @@ + diff --git a/src/SignalR/common/Protocols.NewtonsoftJson/src/Protocol/NewtonsoftJsonHubProtocol.cs b/src/SignalR/common/Protocols.NewtonsoftJson/src/Protocol/NewtonsoftJsonHubProtocol.cs index 2df8002c66d9..11dd9d107adb 100644 --- a/src/SignalR/common/Protocols.NewtonsoftJson/src/Protocol/NewtonsoftJsonHubProtocol.cs +++ b/src/SignalR/common/Protocols.NewtonsoftJson/src/Protocol/NewtonsoftJsonHubProtocol.cs @@ -209,21 +209,28 @@ public ReadOnlyMemory GetMessageBytes(HubMessage message) else { // If we have an invocation id already we can parse the end result - var returnType = binder.GetReturnType(invocationId); - - if (!JsonUtils.ReadForType(reader, returnType)) - { - throw new JsonReaderException("Unexpected end when reading JSON"); - } - - if (returnType == typeof(RawResult)) + var returnType = ProtocolHelper.TryGetReturnType(binder, invocationId); + if (returnType is null) { - var token = JToken.Load(reader); - result = GetRawResult(token); + reader.Skip(); + result = null; } else { - result = PayloadSerializer.Deserialize(reader, returnType); + if (!JsonUtils.ReadForType(reader, returnType)) + { + throw new JsonReaderException("Unexpected end when reading JSON"); + } + + if (returnType == typeof(RawResult)) + { + var token = JToken.Load(reader); + result = GetRawResult(token); + } + else + { + result = PayloadSerializer.Deserialize(reader, returnType); + } } } break; @@ -397,14 +404,21 @@ public ReadOnlyMemory GetMessageBytes(HubMessage message) if (resultToken != null) { - var returnType = binder.GetReturnType(invocationId); - if (returnType == typeof(RawResult)) + var returnType = ProtocolHelper.TryGetReturnType(binder, invocationId); + if (returnType is null) { - result = GetRawResult(resultToken); + result = null; } else { - result = resultToken.ToObject(returnType, PayloadSerializer); + if (returnType == typeof(RawResult)) + { + result = GetRawResult(resultToken); + } + else + { + result = resultToken.ToObject(returnType, PayloadSerializer); + } } } diff --git a/src/SignalR/common/Shared/ClientResultsManager.cs b/src/SignalR/common/Shared/ClientResultsManager.cs index 97b5d2c7023d..12544fb649dd 100644 --- a/src/SignalR/common/Shared/ClientResultsManager.cs +++ b/src/SignalR/common/Shared/ClientResultsManager.cs @@ -144,7 +144,7 @@ public void RegisterCancellation() { // TODO: RedisHubLifetimeManager will want to notify the other server (if there is one) about the cancellation // so it can clean up state and potentially forward that info to the connection - _clientResultsManager.TryCompleteResult(_connectionId, CompletionMessage.WithError(_invocationId, "Canceled")); + _clientResultsManager.TryCompleteResult(_connectionId, CompletionMessage.WithError(_invocationId, "Invocation canceled by the server.")); } public new void SetResult(T result) diff --git a/src/SignalR/common/Shared/TryGetReturnType.cs b/src/SignalR/common/Shared/TryGetReturnType.cs new file mode 100644 index 000000000000..1cfcd0d189f1 --- /dev/null +++ b/src/SignalR/common/Shared/TryGetReturnType.cs @@ -0,0 +1,24 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; + +namespace Microsoft.AspNetCore.SignalR.Protocol; + +internal static class ProtocolHelper +{ + internal static Type? TryGetReturnType(IInvocationBinder binder, string invocationId) + { + try + { + return binder.GetReturnType(invocationId); + } + // GetReturnType throws if invocationId not found, this can be caused by the server canceling a client-result but the client still sending a result + // For now let's ignore the failure and skip parsing the result, server will log that the result wasn't expected anymore and ignore the message + // In the future we may want a CompletionBindingFailureMessage that we can flow to the dispatcher for handling + catch (Exception) + { + return null; + } + } +} diff --git a/src/SignalR/common/SignalR.Common/test/Internal/Protocol/JsonHubProtocolTestsBase.cs b/src/SignalR/common/SignalR.Common/test/Internal/Protocol/JsonHubProtocolTestsBase.cs index dcc68f102b14..59e765acf0c4 100644 --- a/src/SignalR/common/SignalR.Common/test/Internal/Protocol/JsonHubProtocolTestsBase.cs +++ b/src/SignalR/common/SignalR.Common/test/Internal/Protocol/JsonHubProtocolTestsBase.cs @@ -1,16 +1,11 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. -using System; using System.Buffers; -using System.Collections.Generic; using System.Globalization; -using System.IO; -using System.Linq; using System.Text; using Microsoft.AspNetCore.Internal; using Microsoft.AspNetCore.SignalR.Protocol; -using Xunit; namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol; @@ -460,6 +455,19 @@ public void RawResultRoundTripsProperly(string testDataName) } } + [Fact] + public void UnexpectedClientResultGivesEmptyCompletionMessage() + { + var binder = new TestBinder(); + var message = Frame("{\"type\":3,\"result\":1,\"invocationId\":\"1\"}"); + var data = new ReadOnlySequence(Encoding.UTF8.GetBytes(message)); + Assert.True(JsonHubProtocol.TryParseMessage(ref data, binder, out var hubMessage)); + + var completion = Assert.IsType(hubMessage); + Assert.Null(completion.Result); + Assert.Equal("1", completion.InvocationId); + } + public static string Frame(string input) { var data = Encoding.UTF8.GetBytes(input); diff --git a/src/SignalR/common/SignalR.Common/test/Internal/Protocol/MessagePackHubProtocolTests.cs b/src/SignalR/common/SignalR.Common/test/Internal/Protocol/MessagePackHubProtocolTests.cs index 978d729cc4d5..647f50cd3324 100644 --- a/src/SignalR/common/SignalR.Common/test/Internal/Protocol/MessagePackHubProtocolTests.cs +++ b/src/SignalR/common/SignalR.Common/test/Internal/Protocol/MessagePackHubProtocolTests.cs @@ -1,15 +1,9 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. -using System; using System.Buffers; -using System.Collections.Generic; -using System.Diagnostics; -using System.IO; -using System.Linq; using Microsoft.AspNetCore.Internal; using Microsoft.AspNetCore.SignalR.Protocol; -using Xunit; namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol; @@ -249,6 +243,19 @@ public void RawResultRoundTripsProperly(string testDataName) } } + [Fact] + public void UnexpectedClientResultGivesEmptyCompletionMessage() + { + var binder = new TestBinder(); + var input = Frame(Convert.FromBase64String("lQOAo3h5egPA")); + var data = new ReadOnlySequence(input); + Assert.True(HubProtocol.TryParseMessage(ref data, binder, out var hubMessage)); + + var completion = Assert.IsType(hubMessage); + Assert.Null(completion.Result); + Assert.Equal("xyz", completion.InvocationId); + } + public class ClientResultTestData { public string Name { get; } diff --git a/src/SignalR/server/Core/src/ClientProxyExtensions.cs b/src/SignalR/server/Core/src/ClientProxyExtensions.cs index 7b2f0c8c8b8f..c7716ef8af51 100644 --- a/src/SignalR/server/Core/src/ClientProxyExtensions.cs +++ b/src/SignalR/server/Core/src/ClientProxyExtensions.cs @@ -227,7 +227,7 @@ public static Task SendAsync(this IClientProxy clientProxy, string method, objec /// The token to monitor for cancellation requests. The default value is . /// A that represents the asynchronous invoke. [SuppressMessage("ApiDesign", "RS0026:Do not add multiple overloads with optional parameters", Justification = "Required to maintain compatibility")] - public static Task InvokeAsync(this ISingleClientProxy clientProxy, string method, CancellationToken cancellationToken = default) + public static Task InvokeAsync(this ISingleClientProxy clientProxy, string method, CancellationToken cancellationToken) { return clientProxy.InvokeCoreAsync(method, Array.Empty(), cancellationToken); } @@ -241,7 +241,7 @@ public static Task InvokeAsync(this ISingleClientProxy clientProxy, string /// The token to monitor for cancellation requests. The default value is . /// A that represents the asynchronous invoke. [SuppressMessage("ApiDesign", "RS0026:Do not add multiple overloads with optional parameters", Justification = "Required to maintain compatibility")] - public static Task InvokeAsync(this ISingleClientProxy clientProxy, string method, object? arg1, CancellationToken cancellationToken = default) + public static Task InvokeAsync(this ISingleClientProxy clientProxy, string method, object? arg1, CancellationToken cancellationToken) { return clientProxy.InvokeCoreAsync(method, new[] { arg1 }, cancellationToken); } @@ -256,7 +256,7 @@ public static Task InvokeAsync(this ISingleClientProxy clientProxy, string /// The token to monitor for cancellation requests. The default value is . /// A that represents the asynchronous invoke. [SuppressMessage("ApiDesign", "RS0026:Do not add multiple overloads with optional parameters", Justification = "Required to maintain compatibility")] - public static Task InvokeAsync(this ISingleClientProxy clientProxy, string method, object? arg1, object? arg2, CancellationToken cancellationToken = default) + public static Task InvokeAsync(this ISingleClientProxy clientProxy, string method, object? arg1, object? arg2, CancellationToken cancellationToken) { return clientProxy.InvokeCoreAsync(method, new[] { arg1, arg2 }, cancellationToken); } @@ -272,7 +272,7 @@ public static Task InvokeAsync(this ISingleClientProxy clientProxy, string /// The token to monitor for cancellation requests. The default value is . /// A that represents the asynchronous invoke. [SuppressMessage("ApiDesign", "RS0026:Do not add multiple overloads with optional parameters", Justification = "Required to maintain compatibility")] - public static Task InvokeAsync(this ISingleClientProxy clientProxy, string method, object? arg1, object? arg2, object? arg3, CancellationToken cancellationToken = default) + public static Task InvokeAsync(this ISingleClientProxy clientProxy, string method, object? arg1, object? arg2, object? arg3, CancellationToken cancellationToken) { return clientProxy.InvokeCoreAsync(method, new[] { arg1, arg2, arg3 }, cancellationToken); } @@ -289,7 +289,7 @@ public static Task InvokeAsync(this ISingleClientProxy clientProxy, string /// The token to monitor for cancellation requests. The default value is . /// A that represents the asynchronous invoke. [SuppressMessage("ApiDesign", "RS0026:Do not add multiple overloads with optional parameters", Justification = "Required to maintain compatibility")] - public static Task InvokeAsync(this ISingleClientProxy clientProxy, string method, object? arg1, object? arg2, object? arg3, object? arg4, CancellationToken cancellationToken = default) + public static Task InvokeAsync(this ISingleClientProxy clientProxy, string method, object? arg1, object? arg2, object? arg3, object? arg4, CancellationToken cancellationToken) { return clientProxy.InvokeCoreAsync(method, new[] { arg1, arg2, arg3, arg4 }, cancellationToken); } @@ -307,7 +307,7 @@ public static Task InvokeAsync(this ISingleClientProxy clientProxy, string /// The token to monitor for cancellation requests. The default value is . /// A that represents the asynchronous invoke. [SuppressMessage("ApiDesign", "RS0026:Do not add multiple overloads with optional parameters", Justification = "Required to maintain compatibility")] - public static Task InvokeAsync(this ISingleClientProxy clientProxy, string method, object? arg1, object? arg2, object? arg3, object? arg4, object? arg5, CancellationToken cancellationToken = default) + public static Task InvokeAsync(this ISingleClientProxy clientProxy, string method, object? arg1, object? arg2, object? arg3, object? arg4, object? arg5, CancellationToken cancellationToken) { return clientProxy.InvokeCoreAsync(method, new[] { arg1, arg2, arg3, arg4, arg5 }, cancellationToken); } @@ -326,7 +326,7 @@ public static Task InvokeAsync(this ISingleClientProxy clientProxy, string /// The token to monitor for cancellation requests. The default value is . /// A that represents the asynchronous invoke. [SuppressMessage("ApiDesign", "RS0026:Do not add multiple overloads with optional parameters", Justification = "Required to maintain compatibility")] - public static Task InvokeAsync(this ISingleClientProxy clientProxy, string method, object? arg1, object? arg2, object? arg3, object? arg4, object? arg5, object? arg6, CancellationToken cancellationToken = default) + public static Task InvokeAsync(this ISingleClientProxy clientProxy, string method, object? arg1, object? arg2, object? arg3, object? arg4, object? arg5, object? arg6, CancellationToken cancellationToken) { return clientProxy.InvokeCoreAsync(method, new[] { arg1, arg2, arg3, arg4, arg5, arg6 }, cancellationToken); } @@ -346,7 +346,7 @@ public static Task InvokeAsync(this ISingleClientProxy clientProxy, string /// The token to monitor for cancellation requests. The default value is . /// A that represents the asynchronous invoke. [SuppressMessage("ApiDesign", "RS0026:Do not add multiple overloads with optional parameters", Justification = "Required to maintain compatibility")] - public static Task InvokeAsync(this ISingleClientProxy clientProxy, string method, object? arg1, object? arg2, object? arg3, object? arg4, object? arg5, object? arg6, object? arg7, CancellationToken cancellationToken = default) + public static Task InvokeAsync(this ISingleClientProxy clientProxy, string method, object? arg1, object? arg2, object? arg3, object? arg4, object? arg5, object? arg6, object? arg7, CancellationToken cancellationToken) { return clientProxy.InvokeCoreAsync(method, new[] { arg1, arg2, arg3, arg4, arg5, arg6, arg7 }, cancellationToken); } @@ -367,7 +367,7 @@ public static Task InvokeAsync(this ISingleClientProxy clientProxy, string /// The token to monitor for cancellation requests. The default value is . /// A that represents the asynchronous invoke. [SuppressMessage("ApiDesign", "RS0026:Do not add multiple overloads with optional parameters", Justification = "Required to maintain compatibility")] - public static Task InvokeAsync(this ISingleClientProxy clientProxy, string method, object? arg1, object? arg2, object? arg3, object? arg4, object? arg5, object? arg6, object? arg7, object? arg8, CancellationToken cancellationToken = default) + public static Task InvokeAsync(this ISingleClientProxy clientProxy, string method, object? arg1, object? arg2, object? arg3, object? arg4, object? arg5, object? arg6, object? arg7, object? arg8, CancellationToken cancellationToken) { return clientProxy.InvokeCoreAsync(method, new[] { arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8 }, cancellationToken); } @@ -389,7 +389,7 @@ public static Task InvokeAsync(this ISingleClientProxy clientProxy, string /// The token to monitor for cancellation requests. The default value is . /// A that represents the asynchronous invoke. [SuppressMessage("ApiDesign", "RS0026:Do not add multiple overloads with optional parameters", Justification = "Required to maintain compatibility")] - public static Task InvokeAsync(this ISingleClientProxy clientProxy, string method, object? arg1, object? arg2, object? arg3, object? arg4, object? arg5, object? arg6, object? arg7, object? arg8, object? arg9, CancellationToken cancellationToken = default) + public static Task InvokeAsync(this ISingleClientProxy clientProxy, string method, object? arg1, object? arg2, object? arg3, object? arg4, object? arg5, object? arg6, object? arg7, object? arg8, object? arg9, CancellationToken cancellationToken) { return clientProxy.InvokeCoreAsync(method, new[] { arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8, arg9 }, cancellationToken); } @@ -412,7 +412,7 @@ public static Task InvokeAsync(this ISingleClientProxy clientProxy, string /// The token to monitor for cancellation requests. The default value is . /// A that represents the asynchronous invoke. [SuppressMessage("ApiDesign", "RS0026:Do not add multiple overloads with optional parameters", Justification = "Required to maintain compatibility")] - public static Task InvokeAsync(this ISingleClientProxy clientProxy, string method, object? arg1, object? arg2, object? arg3, object? arg4, object? arg5, object? arg6, object? arg7, object? arg8, object? arg9, object? arg10, CancellationToken cancellationToken = default) + public static Task InvokeAsync(this ISingleClientProxy clientProxy, string method, object? arg1, object? arg2, object? arg3, object? arg4, object? arg5, object? arg6, object? arg7, object? arg8, object? arg9, object? arg10, CancellationToken cancellationToken) { return clientProxy.InvokeCoreAsync(method, new[] { arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8, arg9, arg10 }, cancellationToken); } diff --git a/src/SignalR/server/Core/src/DefaultHubLifetimeManager.cs b/src/SignalR/server/Core/src/DefaultHubLifetimeManager.cs index 45509fee030b..230af87fc33e 100644 --- a/src/SignalR/server/Core/src/DefaultHubLifetimeManager.cs +++ b/src/SignalR/server/Core/src/DefaultHubLifetimeManager.cs @@ -326,7 +326,7 @@ public override Task SendUsersAsync(IReadOnlyList userIds, string method } /// - public override async Task InvokeConnectionAsync(string connectionId, string methodName, object?[] args, CancellationToken cancellationToken = default) + public override async Task InvokeConnectionAsync(string connectionId, string methodName, object?[] args, CancellationToken cancellationToken) { if (connectionId == null) { @@ -341,6 +341,7 @@ public override async Task InvokeConnectionAsync(string connectionId, stri } var invocationId = Interlocked.Increment(ref _lastInvocationId).ToString(NumberFormatInfo.InvariantInfo); + using var _ = CancellationTokenUtils.CreateLinkedToken(cancellationToken, connection.ConnectionAborted, out var linkedToken); var task = _clientResultsManager.AddInvocation(connectionId, invocationId, linkedToken); diff --git a/src/SignalR/server/Core/src/HubConnectionContext.cs b/src/SignalR/server/Core/src/HubConnectionContext.cs index 04e211d74f86..5ac7769617ce 100644 --- a/src/SignalR/server/Core/src/HubConnectionContext.cs +++ b/src/SignalR/server/Core/src/HubConnectionContext.cs @@ -79,12 +79,8 @@ public HubConnectionContext(ConnectionContext connectionContext, HubConnectionCo _systemClock = contextOptions.SystemClock ?? new SystemClock(); _lastSendTick = _systemClock.CurrentTicks; - // We'll be avoiding using the semaphore when the limit is set to 1, so no need to allocate it var maxInvokeLimit = contextOptions.MaximumParallelInvocations; - if (maxInvokeLimit != 1) - { - ActiveInvocationLimit = new SemaphoreSlim(maxInvokeLimit, maxInvokeLimit); - } + ActiveInvocationLimit = new ChannelBasedSemaphore(maxInvokeLimit); } internal StreamTracker StreamTracker @@ -102,11 +98,10 @@ internal StreamTracker StreamTracker } internal HubCallerContext HubCallerContext { get; } - internal HubCallerClients HubCallerClients { get; set; } = null!; internal Exception? CloseException { get; private set; } - internal SemaphoreSlim? ActiveInvocationLimit { get; } + internal ChannelBasedSemaphore ActiveInvocationLimit { get; } /// /// Gets a that notifies when the connection is aborted. diff --git a/src/SignalR/server/Core/src/HubConnectionHandler.cs b/src/SignalR/server/Core/src/HubConnectionHandler.cs index 3bb5566e6a7a..3269d6f7afcd 100644 --- a/src/SignalR/server/Core/src/HubConnectionHandler.cs +++ b/src/SignalR/server/Core/src/HubConnectionHandler.cs @@ -200,6 +200,9 @@ private async Task HubOnDisconnectedAsync(HubConnectionContext connection, Excep // Ensure the connection is aborted before firing disconnect await connection.AbortAsync(); + // If a client result is requested in OnDisconnectedAsync we want to avoid the SemaphoreFullException and get the better connection disconnected IOException + _ = connection.ActiveInvocationLimit.TryAcquire(); + try { await _dispatcher.OnDisconnectedAsync(connection, exception); diff --git a/src/SignalR/server/Core/src/HubLifetimeManager.cs b/src/SignalR/server/Core/src/HubLifetimeManager.cs index 14a294190876..f1bc8b058074 100644 --- a/src/SignalR/server/Core/src/HubLifetimeManager.cs +++ b/src/SignalR/server/Core/src/HubLifetimeManager.cs @@ -142,9 +142,9 @@ public abstract class HubLifetimeManager where THub : Hub /// The connection ID. /// The invocation method name. /// The invocation arguments. - /// The token to monitor for cancellation requests. The default value is . + /// The token to monitor for cancellation requests. It is recommended to set a max wait for expecting a result. /// The response from the connection. - public virtual Task InvokeConnectionAsync(string connectionId, string methodName, object?[] args, CancellationToken cancellationToken = default) + public virtual Task InvokeConnectionAsync(string connectionId, string methodName, object?[] args, CancellationToken cancellationToken) { throw new NotImplementedException($"{GetType().Name} does not support client return values."); } diff --git a/src/SignalR/server/Core/src/ISingleClientProxy.cs b/src/SignalR/server/Core/src/ISingleClientProxy.cs index f400b13e6acc..9a4451e3810e 100644 --- a/src/SignalR/server/Core/src/ISingleClientProxy.cs +++ b/src/SignalR/server/Core/src/ISingleClientProxy.cs @@ -18,7 +18,7 @@ public interface ISingleClientProxy : IClientProxy /// /// Name of the method to invoke. /// A collection of arguments to pass to the client. - /// The token to monitor for cancellation requests. The default value is . + /// The token to monitor for cancellation requests. It is recommended to set a max wait for expecting a result. /// A that represents the asynchronous invoke and wait for a client result. - Task InvokeCoreAsync(string method, object?[] args, CancellationToken cancellationToken = default); + Task InvokeCoreAsync(string method, object?[] args, CancellationToken cancellationToken); } diff --git a/src/SignalR/server/Core/src/Internal/ChannelBasedSemaphore.cs b/src/SignalR/server/Core/src/Internal/ChannelBasedSemaphore.cs new file mode 100644 index 000000000000..8b6bbbe0ec6f --- /dev/null +++ b/src/SignalR/server/Core/src/Internal/ChannelBasedSemaphore.cs @@ -0,0 +1,77 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Diagnostics; +using System.Threading.Channels; + +namespace Microsoft.AspNetCore.SignalR.Internal; + +// Use a Channel instead of a SemaphoreSlim so that we can potentially save task allocations (ValueTask!) +// Additionally initial perf results show faster RPS when using Channel instead of SemaphoreSlim +internal sealed class ChannelBasedSemaphore +{ + private readonly Channel _channel; + + public ChannelBasedSemaphore(int maxCapacity) + { + _channel = Channel.CreateBounded(maxCapacity); + for (var i = 0; i < maxCapacity; i++) + { + _channel.Writer.TryWrite(1); + } + } + + public bool TryAcquire() + { + return _channel.Reader.TryRead(out _); + } + + // The int result isn't important, only reason it's exposed is because ValueTask doesn't implement ValueTask so we can't cast like we could with Task to Task + public ValueTask WaitAsync(CancellationToken cancellationToken = default) + { + return _channel.Reader.ReadAsync(cancellationToken); + } + + public void Release() + { + if (!_channel.Writer.TryWrite(1)) + { + throw new SemaphoreFullException(); + } + } + + public ValueTask RunAsync(Func> callback, TState state) + { + if (TryAcquire()) + { + _ = RunTask(callback, state); + return ValueTask.CompletedTask; + } + + return RunSlowAsync(callback, state); + } + + private async ValueTask RunSlowAsync(Func> callback, TState state) + { + _ = await WaitAsync(); + _ = RunTask(callback, state); + } + + private async Task RunTask(Func> callback, TState state) + { + try + { + var shouldRelease = await callback(state); + if (shouldRelease) + { + Release(); + } + } + catch + { + // DefaultHubDispatcher catches and handles exceptions + // It does write to the connection in exception cases which also can't throw because we catch and log in HubConnectionContext + Debug.Assert(false); + } + } +} diff --git a/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs b/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs index 59b06dbf101a..30c06e594c7c 100644 --- a/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs +++ b/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs @@ -73,13 +73,13 @@ public DefaultHubDispatcher(IServiceScopeFactory serviceScopeFactory, IHubContex public override async Task OnConnectedAsync(HubConnectionContext connection) { await using var scope = _serviceScopeFactory.CreateAsyncScope(); - connection.HubCallerClients = new HubCallerClients(_hubContext.Clients, connection.ConnectionId, connection.ActiveInvocationLimit is not null); var hubActivator = scope.ServiceProvider.GetRequiredService>(); var hub = hubActivator.Create(); try { - InitializeHub(hub, connection); + // OnConnectedAsync won't work with client results (ISingleClientProxy.InvokeAsync) + InitializeHub(hub, connection, invokeAllowed: false); if (_onConnectedMiddleware != null) { @@ -90,9 +90,6 @@ public override async Task OnConnectedAsync(HubConnectionContext connection) { await hub.OnConnectedAsync(); } - - // OnConnectedAsync is finished, allow hub methods to use client results (ISingleClientProxy.InvokeAsync) - connection.HubCallerClients.InvokeAllowed = true; } finally { @@ -256,13 +253,13 @@ private Task ProcessInvocation(HubConnectionContext connection, else { bool isStreamCall = descriptor.StreamingParameters != null; - if (connection.ActiveInvocationLimit != null && !isStreamCall && !isStreamResponse) + if (!isStreamCall && !isStreamResponse) { return connection.ActiveInvocationLimit.RunAsync(static state => { var (dispatcher, descriptor, connection, invocationMessage) = state; return dispatcher.Invoke(descriptor, connection, invocationMessage, isStreamResponse: false, isStreamCall: false); - }, (this, descriptor, connection, hubMethodInvocationMessage)); + }, (this, descriptor, connection, hubMethodInvocationMessage)).AsTask(); } else { @@ -271,11 +268,12 @@ private Task ProcessInvocation(HubConnectionContext connection, } } - private async Task Invoke(HubMethodDescriptor descriptor, HubConnectionContext connection, + private async Task Invoke(HubMethodDescriptor descriptor, HubConnectionContext connection, HubMethodInvocationMessage hubMethodInvocationMessage, bool isStreamResponse, bool isStreamCall) { var methodExecutor = descriptor.MethodExecutor; + var wasSemaphoreReleased = false; var disposeScope = true; var scope = _serviceScopeFactory.CreateAsyncScope(); IHubActivator? hubActivator = null; @@ -290,12 +288,12 @@ private async Task Invoke(HubMethodDescriptor descriptor, HubConnectionContext c Log.HubMethodNotAuthorized(_logger, hubMethodInvocationMessage.Target); await SendInvocationError(hubMethodInvocationMessage.InvocationId, connection, $"Failed to invoke '{hubMethodInvocationMessage.Target}' because user is unauthorized"); - return; + return true; } if (!await ValidateInvocationMode(descriptor, isStreamResponse, hubMethodInvocationMessage, connection)) { - return; + return true; } try @@ -308,7 +306,7 @@ await SendInvocationError(hubMethodInvocationMessage.InvocationId, connection, Log.InvalidHubParameters(_logger, hubMethodInvocationMessage.Target, ex); await SendInvocationError(hubMethodInvocationMessage.InvocationId, connection, ErrorMessageHelper.BuildErrorMessage($"An unexpected error occurred invoking '{hubMethodInvocationMessage.Target}' on the server.", ex, _enableDetailedErrors)); - return; + return true; } InitializeHub(hub, connection); @@ -404,9 +402,15 @@ await SendInvocationError(hubMethodInvocationMessage.InvocationId, connection, { if (disposeScope) { + if (hub?.Clients is HubCallerClients hubCallerClients) + { + wasSemaphoreReleased = Interlocked.CompareExchange(ref hubCallerClients.ShouldReleaseSemaphore, 0, 1) == 0; + } await CleanupInvocation(connection, hubMethodInvocationMessage, hubActivator, hub, scope); } } + + return !wasSemaphoreReleased; } private static ValueTask CleanupInvocation(HubConnectionContext connection, HubMethodInvocationMessage hubMessage, IHubActivator? hubActivator, @@ -553,9 +557,9 @@ private static async Task SendInvocationError(string? invocationId, await connection.WriteAsync(CompletionMessage.WithError(invocationId, errorMessage)); } - private void InitializeHub(THub hub, HubConnectionContext connection) + private void InitializeHub(THub hub, HubConnectionContext connection, bool invokeAllowed = true) { - hub.Clients = connection.HubCallerClients; + hub.Clients = new HubCallerClients(_hubContext.Clients, connection.ConnectionId, connection.ActiveInvocationLimit) { InvokeAllowed = invokeAllowed }; hub.Context = connection.HubCallerContext; hub.Groups = _hubContext.Groups; } diff --git a/src/SignalR/server/Core/src/Internal/HubCallerClients.cs b/src/SignalR/server/Core/src/Internal/HubCallerClients.cs index e2a65ca7d1d9..8e6ec0fa0dc9 100644 --- a/src/SignalR/server/Core/src/Internal/HubCallerClients.cs +++ b/src/SignalR/server/Core/src/Internal/HubCallerClients.cs @@ -7,20 +7,20 @@ internal sealed class HubCallerClients : IHubCallerClients { private readonly string _connectionId; private readonly IHubClients _hubClients; - private readonly string[] _currentConnectionId; - private readonly bool _parallelEnabled; + internal readonly ChannelBasedSemaphore _parallelInvokes; + + internal int ShouldReleaseSemaphore = 1; // Client results don't work in OnConnectedAsync // This property is set by the hub dispatcher when those methods are being called // so we can prevent users from making blocking client calls by returning a custom ISingleClientProxy instance internal bool InvokeAllowed { get; set; } - public HubCallerClients(IHubClients hubClients, string connectionId, bool parallelEnabled) + public HubCallerClients(IHubClients hubClients, string connectionId, ChannelBasedSemaphore parallelInvokes) { _connectionId = connectionId; _hubClients = hubClients; - _currentConnectionId = new[] { _connectionId }; - _parallelEnabled = parallelEnabled; + _parallelInvokes = parallelInvokes; } IClientProxy IHubCallerClients.Caller => Caller; @@ -28,19 +28,15 @@ public ISingleClientProxy Caller { get { - if (!_parallelEnabled) - { - return new NotParallelSingleClientProxy(_hubClients.Client(_connectionId)); - } if (!InvokeAllowed) { return new NoInvokeSingleClientProxy(_hubClients.Client(_connectionId)); } - return _hubClients.Client(_connectionId); + return new SingleClientProxy(_hubClients.Client(_connectionId), this); } } - public IClientProxy Others => _hubClients.AllExcept(_currentConnectionId); + public IClientProxy Others => _hubClients.AllExcept(new[] { _connectionId }); public IClientProxy All => _hubClients.All; @@ -52,15 +48,11 @@ public IClientProxy AllExcept(IReadOnlyList excludedConnectionIds) IClientProxy IHubClients.Client(string connectionId) => Client(connectionId); public ISingleClientProxy Client(string connectionId) { - if (!_parallelEnabled) - { - return new NotParallelSingleClientProxy(_hubClients.Client(connectionId)); - } if (!InvokeAllowed) { return new NoInvokeSingleClientProxy(_hubClients.Client(_connectionId)); } - return _hubClients.Client(connectionId); + return new SingleClientProxy(_hubClients.Client(connectionId), this); } public IClientProxy Group(string groupName) @@ -75,7 +67,7 @@ public IClientProxy Groups(IReadOnlyList groupNames) public IClientProxy OthersInGroup(string groupName) { - return _hubClients.GroupExcept(groupName, _currentConnectionId); + return _hubClients.GroupExcept(groupName, new[] { _connectionId }); } public IClientProxy GroupExcept(string groupName, IReadOnlyList excludedConnectionIds) @@ -98,18 +90,18 @@ public IClientProxy Users(IReadOnlyList userIds) return _hubClients.Users(userIds); } - private sealed class NotParallelSingleClientProxy : ISingleClientProxy + private sealed class NoInvokeSingleClientProxy : ISingleClientProxy { private readonly ISingleClientProxy _proxy; - public NotParallelSingleClientProxy(ISingleClientProxy hubClients) + public NoInvokeSingleClientProxy(ISingleClientProxy hubClients) { _proxy = hubClients; } public Task InvokeCoreAsync(string method, object?[] args, CancellationToken cancellationToken = default) { - throw new InvalidOperationException("Client results inside a Hub method requires HubOptions.MaximumParallelInvocationsPerClient to be greater than 1."); + throw new InvalidOperationException("Client results inside OnConnectedAsync Hub methods are not allowed."); } public Task SendCoreAsync(string method, object?[] args, CancellationToken cancellationToken = default) @@ -118,18 +110,29 @@ public Task SendCoreAsync(string method, object?[] args, CancellationToken cance } } - private sealed class NoInvokeSingleClientProxy : ISingleClientProxy + private sealed class SingleClientProxy : ISingleClientProxy { private readonly ISingleClientProxy _proxy; + private readonly HubCallerClients _hubCallerClients; - public NoInvokeSingleClientProxy(ISingleClientProxy hubClients) + public SingleClientProxy(ISingleClientProxy hubClients, HubCallerClients hubCallerClients) { _proxy = hubClients; + _hubCallerClients = hubCallerClients; } - public Task InvokeCoreAsync(string method, object?[] args, CancellationToken cancellationToken = default) + public async Task InvokeCoreAsync(string method, object?[] args, CancellationToken cancellationToken = default) { - throw new InvalidOperationException("Client results inside OnConnectedAsync Hub methods are not allowed."); + // Releases the Channel that is blocking pending invokes, which in turn can block the receive loop. + // Because we are waiting for a result from the client we need to let the receive loop run otherwise we'll be blocked forever + var value = Interlocked.CompareExchange(ref _hubCallerClients.ShouldReleaseSemaphore, 0, 1); + // Only release once, and we set ShouldReleaseSemaphore to 0 so the DefaultHubDispatcher knows not to call Release again + if (value == 1) + { + _hubCallerClients._parallelInvokes.Release(); + } + var result = await _proxy.InvokeCoreAsync(method, args, cancellationToken); + return result; } public Task SendCoreAsync(string method, object?[] args, CancellationToken cancellationToken = default) diff --git a/src/SignalR/server/Core/src/Internal/HubConnectionBinder.cs b/src/SignalR/server/Core/src/Internal/HubConnectionBinder.cs index 45d80673825e..9b844929b951 100644 --- a/src/SignalR/server/Core/src/Internal/HubConnectionBinder.cs +++ b/src/SignalR/server/Core/src/Internal/HubConnectionBinder.cs @@ -27,6 +27,7 @@ public Type GetReturnType(string invocationId) { return type; } + // If the id isn't found then it's possible the server canceled the request for a result but the client still sent the result. throw new InvalidOperationException($"Unknown invocation ID '{invocationId}'."); } diff --git a/src/SignalR/server/Core/src/Internal/SemaphoreSlimExtensions.cs b/src/SignalR/server/Core/src/Internal/SemaphoreSlimExtensions.cs deleted file mode 100644 index a238d09643e3..000000000000 --- a/src/SignalR/server/Core/src/Internal/SemaphoreSlimExtensions.cs +++ /dev/null @@ -1,36 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. - -namespace Microsoft.AspNetCore.SignalR.Internal; - -internal static class SemaphoreSlimExtensions -{ - public static Task RunAsync(this SemaphoreSlim semaphoreSlim, Func callback, TState state) - { - if (semaphoreSlim.Wait(0)) - { - _ = RunTask(callback, semaphoreSlim, state); - return Task.CompletedTask; - } - - return RunSlowAsync(semaphoreSlim, callback, state); - } - - private static async Task RunSlowAsync(this SemaphoreSlim semaphoreSlim, Func callback, TState state) - { - await semaphoreSlim.WaitAsync(); - return RunTask(callback, semaphoreSlim, state); - } - - static async Task RunTask(Func callback, SemaphoreSlim semaphoreSlim, TState state) - { - try - { - await callback(state); - } - finally - { - semaphoreSlim.Release(); - } - } -} diff --git a/src/SignalR/server/Core/src/PublicAPI.Unshipped.txt b/src/SignalR/server/Core/src/PublicAPI.Unshipped.txt index 860477539be5..f7dd250a0ce1 100644 --- a/src/SignalR/server/Core/src/PublicAPI.Unshipped.txt +++ b/src/SignalR/server/Core/src/PublicAPI.Unshipped.txt @@ -5,21 +5,21 @@ Microsoft.AspNetCore.SignalR.IHubCallerClients.Caller.get -> Microsoft.AspNetCor Microsoft.AspNetCore.SignalR.IHubCallerClients.Client(string! connectionId) -> Microsoft.AspNetCore.SignalR.ISingleClientProxy! Microsoft.AspNetCore.SignalR.IHubClients.Client(string! connectionId) -> Microsoft.AspNetCore.SignalR.ISingleClientProxy! Microsoft.AspNetCore.SignalR.ISingleClientProxy -Microsoft.AspNetCore.SignalR.ISingleClientProxy.InvokeCoreAsync(string! method, object?[]! args, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task! -override Microsoft.AspNetCore.SignalR.DefaultHubLifetimeManager.InvokeConnectionAsync(string! connectionId, string! methodName, object?[]! args, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task! +Microsoft.AspNetCore.SignalR.ISingleClientProxy.InvokeCoreAsync(string! method, object?[]! args, System.Threading.CancellationToken cancellationToken) -> System.Threading.Tasks.Task! +override Microsoft.AspNetCore.SignalR.DefaultHubLifetimeManager.InvokeConnectionAsync(string! connectionId, string! methodName, object?[]! args, System.Threading.CancellationToken cancellationToken) -> System.Threading.Tasks.Task! override Microsoft.AspNetCore.SignalR.DefaultHubLifetimeManager.SetConnectionResultAsync(string! connectionId, Microsoft.AspNetCore.SignalR.Protocol.CompletionMessage! result) -> System.Threading.Tasks.Task! override Microsoft.AspNetCore.SignalR.DefaultHubLifetimeManager.TryGetReturnType(string! invocationId, out System.Type? type) -> bool -static Microsoft.AspNetCore.SignalR.ClientProxyExtensions.InvokeAsync(this Microsoft.AspNetCore.SignalR.ISingleClientProxy! clientProxy, string! method, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task! -static Microsoft.AspNetCore.SignalR.ClientProxyExtensions.InvokeAsync(this Microsoft.AspNetCore.SignalR.ISingleClientProxy! clientProxy, string! method, object? arg1, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task! -static Microsoft.AspNetCore.SignalR.ClientProxyExtensions.InvokeAsync(this Microsoft.AspNetCore.SignalR.ISingleClientProxy! clientProxy, string! method, object? arg1, object? arg2, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task! -static Microsoft.AspNetCore.SignalR.ClientProxyExtensions.InvokeAsync(this Microsoft.AspNetCore.SignalR.ISingleClientProxy! clientProxy, string! method, object? arg1, object? arg2, object? arg3, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task! -static Microsoft.AspNetCore.SignalR.ClientProxyExtensions.InvokeAsync(this Microsoft.AspNetCore.SignalR.ISingleClientProxy! clientProxy, string! method, object? arg1, object? arg2, object? arg3, object? arg4, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task! -static Microsoft.AspNetCore.SignalR.ClientProxyExtensions.InvokeAsync(this Microsoft.AspNetCore.SignalR.ISingleClientProxy! clientProxy, string! method, object? arg1, object? arg2, object? arg3, object? arg4, object? arg5, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task! -static Microsoft.AspNetCore.SignalR.ClientProxyExtensions.InvokeAsync(this Microsoft.AspNetCore.SignalR.ISingleClientProxy! clientProxy, string! method, object? arg1, object? arg2, object? arg3, object? arg4, object? arg5, object? arg6, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task! -static Microsoft.AspNetCore.SignalR.ClientProxyExtensions.InvokeAsync(this Microsoft.AspNetCore.SignalR.ISingleClientProxy! clientProxy, string! method, object? arg1, object? arg2, object? arg3, object? arg4, object? arg5, object? arg6, object? arg7, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task! -static Microsoft.AspNetCore.SignalR.ClientProxyExtensions.InvokeAsync(this Microsoft.AspNetCore.SignalR.ISingleClientProxy! clientProxy, string! method, object? arg1, object? arg2, object? arg3, object? arg4, object? arg5, object? arg6, object? arg7, object? arg8, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task! -static Microsoft.AspNetCore.SignalR.ClientProxyExtensions.InvokeAsync(this Microsoft.AspNetCore.SignalR.ISingleClientProxy! clientProxy, string! method, object? arg1, object? arg2, object? arg3, object? arg4, object? arg5, object? arg6, object? arg7, object? arg8, object? arg9, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task! -static Microsoft.AspNetCore.SignalR.ClientProxyExtensions.InvokeAsync(this Microsoft.AspNetCore.SignalR.ISingleClientProxy! clientProxy, string! method, object? arg1, object? arg2, object? arg3, object? arg4, object? arg5, object? arg6, object? arg7, object? arg8, object? arg9, object? arg10, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task! -virtual Microsoft.AspNetCore.SignalR.HubLifetimeManager.InvokeConnectionAsync(string! connectionId, string! methodName, object?[]! args, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task! +static Microsoft.AspNetCore.SignalR.ClientProxyExtensions.InvokeAsync(this Microsoft.AspNetCore.SignalR.ISingleClientProxy! clientProxy, string! method, object? arg1, object? arg2, object? arg3, object? arg4, object? arg5, object? arg6, object? arg7, object? arg8, object? arg9, object? arg10, System.Threading.CancellationToken cancellationToken) -> System.Threading.Tasks.Task! +static Microsoft.AspNetCore.SignalR.ClientProxyExtensions.InvokeAsync(this Microsoft.AspNetCore.SignalR.ISingleClientProxy! clientProxy, string! method, object? arg1, object? arg2, object? arg3, object? arg4, object? arg5, object? arg6, object? arg7, object? arg8, object? arg9, System.Threading.CancellationToken cancellationToken) -> System.Threading.Tasks.Task! +static Microsoft.AspNetCore.SignalR.ClientProxyExtensions.InvokeAsync(this Microsoft.AspNetCore.SignalR.ISingleClientProxy! clientProxy, string! method, object? arg1, object? arg2, object? arg3, object? arg4, object? arg5, object? arg6, object? arg7, object? arg8, System.Threading.CancellationToken cancellationToken) -> System.Threading.Tasks.Task! +static Microsoft.AspNetCore.SignalR.ClientProxyExtensions.InvokeAsync(this Microsoft.AspNetCore.SignalR.ISingleClientProxy! clientProxy, string! method, object? arg1, object? arg2, object? arg3, object? arg4, object? arg5, object? arg6, object? arg7, System.Threading.CancellationToken cancellationToken) -> System.Threading.Tasks.Task! +static Microsoft.AspNetCore.SignalR.ClientProxyExtensions.InvokeAsync(this Microsoft.AspNetCore.SignalR.ISingleClientProxy! clientProxy, string! method, object? arg1, object? arg2, object? arg3, object? arg4, object? arg5, object? arg6, System.Threading.CancellationToken cancellationToken) -> System.Threading.Tasks.Task! +static Microsoft.AspNetCore.SignalR.ClientProxyExtensions.InvokeAsync(this Microsoft.AspNetCore.SignalR.ISingleClientProxy! clientProxy, string! method, object? arg1, object? arg2, object? arg3, object? arg4, object? arg5, System.Threading.CancellationToken cancellationToken) -> System.Threading.Tasks.Task! +static Microsoft.AspNetCore.SignalR.ClientProxyExtensions.InvokeAsync(this Microsoft.AspNetCore.SignalR.ISingleClientProxy! clientProxy, string! method, object? arg1, object? arg2, object? arg3, object? arg4, System.Threading.CancellationToken cancellationToken) -> System.Threading.Tasks.Task! +static Microsoft.AspNetCore.SignalR.ClientProxyExtensions.InvokeAsync(this Microsoft.AspNetCore.SignalR.ISingleClientProxy! clientProxy, string! method, object? arg1, object? arg2, object? arg3, System.Threading.CancellationToken cancellationToken) -> System.Threading.Tasks.Task! +static Microsoft.AspNetCore.SignalR.ClientProxyExtensions.InvokeAsync(this Microsoft.AspNetCore.SignalR.ISingleClientProxy! clientProxy, string! method, object? arg1, object? arg2, System.Threading.CancellationToken cancellationToken) -> System.Threading.Tasks.Task! +static Microsoft.AspNetCore.SignalR.ClientProxyExtensions.InvokeAsync(this Microsoft.AspNetCore.SignalR.ISingleClientProxy! clientProxy, string! method, object? arg1, System.Threading.CancellationToken cancellationToken) -> System.Threading.Tasks.Task! +static Microsoft.AspNetCore.SignalR.ClientProxyExtensions.InvokeAsync(this Microsoft.AspNetCore.SignalR.ISingleClientProxy! clientProxy, string! method, System.Threading.CancellationToken cancellationToken) -> System.Threading.Tasks.Task! +virtual Microsoft.AspNetCore.SignalR.HubLifetimeManager.InvokeConnectionAsync(string! connectionId, string! methodName, object?[]! args, System.Threading.CancellationToken cancellationToken) -> System.Threading.Tasks.Task! virtual Microsoft.AspNetCore.SignalR.HubLifetimeManager.SetConnectionResultAsync(string! connectionId, Microsoft.AspNetCore.SignalR.Protocol.CompletionMessage! result) -> System.Threading.Tasks.Task! virtual Microsoft.AspNetCore.SignalR.HubLifetimeManager.TryGetReturnType(string! invocationId, out System.Type? type) -> bool diff --git a/src/SignalR/server/SignalR/test/ClientProxyTests.cs b/src/SignalR/server/SignalR/test/ClientProxyTests.cs index 784fc98bf01f..708935fc8e6d 100644 --- a/src/SignalR/server/SignalR/test/ClientProxyTests.cs +++ b/src/SignalR/server/SignalR/test/ClientProxyTests.cs @@ -213,7 +213,7 @@ public async Task SingleClientProxyWithInvoke_ThrowsNotSupported() var hubLifetimeManager = new EmptyHubLifetimeManager(); var proxy = new SingleClientProxy(hubLifetimeManager, ""); - var ex = await Assert.ThrowsAsync(async () => await proxy.InvokeAsync("method")).DefaultTimeout(); + var ex = await Assert.ThrowsAsync(async () => await proxy.InvokeAsync("method", cancellationToken: default)).DefaultTimeout(); Assert.Equal("EmptyHubLifetimeManager`1 does not support client return values.", ex.Message); } diff --git a/src/SignalR/server/SignalR/test/HubConnectionHandlerTestUtils/Hubs.cs b/src/SignalR/server/SignalR/test/HubConnectionHandlerTestUtils/Hubs.cs index ca84bad5cafc..dc4ad919292e 100644 --- a/src/SignalR/server/SignalR/test/HubConnectionHandlerTestUtils/Hubs.cs +++ b/src/SignalR/server/SignalR/test/HubConnectionHandlerTestUtils/Hubs.cs @@ -335,9 +335,27 @@ public async Task BlockingMethod() public async Task GetClientResult(int num) { - var sum = await Clients.Caller.InvokeAsync("Sum", num); + var sum = await Clients.Caller.InvokeAsync("Sum", num, cancellationToken: default); return sum; } + + public void BackgroundClientResult(TcsService tcsService) + { + var caller = Clients.Caller; + _ = Task.Run(async () => + { + try + { + await tcsService.StartedMethod.Task; + var result = await caller.InvokeAsync("GetResult", 1, CancellationToken.None); + tcsService.EndMethod.SetResult(result); + } + catch (Exception ex) + { + tcsService.EndMethod.SetException(ex); + } + }); + } } internal class SelfRef @@ -537,6 +555,8 @@ public interface ITest Task Broadcast(string message); Task GetClientResult(int value); + + Task GetClientResultWithCancellation(int value, CancellationToken cancellationToken); } public record ClientResults(int ClientResult, int CallerResult); @@ -1258,7 +1278,7 @@ public class OnConnectedClientResultHub : Hub { public override async Task OnConnectedAsync() { - await Clients.Caller.InvokeAsync("Test"); + await Clients.Caller.InvokeAsync("Test", cancellationToken: default); } } @@ -1266,7 +1286,7 @@ public class OnDisconnectedClientResultHub : Hub { public override async Task OnDisconnectedAsync(Exception ex) { - await Clients.Caller.InvokeAsync("Test"); + await Clients.Caller.InvokeAsync("Test", cancellationToken: default); } } diff --git a/src/SignalR/server/SignalR/test/HubConnectionHandlerTests.ClientResult.cs b/src/SignalR/server/SignalR/test/HubConnectionHandlerTests.ClientResult.cs index 91133c96048d..1320c674d622 100644 --- a/src/SignalR/server/SignalR/test/HubConnectionHandlerTests.ClientResult.cs +++ b/src/SignalR/server/SignalR/test/HubConnectionHandlerTests.ClientResult.cs @@ -14,11 +14,7 @@ public async Task CanReturnClientResultToHub() { using (StartVerifiableLog()) { - var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(builder => - { - // Waiting for a client result blocks the hub dispatcher pipeline, need to allow multiple invocations - builder.AddSignalR(o => o.MaximumParallelInvocationsPerClient = 2); - }, LoggerFactory); + var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(builder => { }, LoggerFactory); var connectionHandler = serviceProvider.GetService>(); using (var client = new TestClient()) @@ -47,10 +43,8 @@ public async Task CanReturnClientResultErrorToHub() { var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(builder => { - // Waiting for a client result blocks the hub dispatcher pipeline, need to allow multiple invocations builder.AddSignalR(o => { - o.MaximumParallelInvocationsPerClient = 2; o.EnableDetailedErrors = true; }); }, LoggerFactory); @@ -74,36 +68,6 @@ public async Task CanReturnClientResultErrorToHub() } } - [Fact] - public async Task ThrowsWhenParallelHubInvokesNotEnabled() - { - using (StartVerifiableLog(write => write.EventId.Name == "FailedInvokingHubMethod")) - { - var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(builder => - { - builder.AddSignalR(o => - { - o.MaximumParallelInvocationsPerClient = 1; - o.EnableDetailedErrors = true; - }); - }, LoggerFactory); - var connectionHandler = serviceProvider.GetService>(); - - using (var client = new TestClient()) - { - var connectionHandlerTask = await client.ConnectAsync(connectionHandler).DefaultTimeout(); - - var invocationId = await client.SendHubMessageAsync(new InvocationMessage("1", nameof(MethodHub.GetClientResult), new object[] { 5 })).DefaultTimeout(); - - // Hub asks client for a result, this is an invocation message with an ID - var completionMessage = Assert.IsType(await client.ReadAsync().DefaultTimeout()); - Assert.Equal(invocationId, completionMessage.InvocationId); - Assert.Equal("An unexpected error occurred invoking 'GetClientResult' on the server. InvalidOperationException: Client results inside a Hub method requires HubOptions.MaximumParallelInvocationsPerClient to be greater than 1.", - completionMessage.Error); - } - } - } - [Fact] public async Task ThrowsWhenUsedInOnConnectedAsync() { @@ -113,7 +77,6 @@ public async Task ThrowsWhenUsedInOnConnectedAsync() { builder.AddSignalR(o => { - o.MaximumParallelInvocationsPerClient = 2; o.EnableDetailedErrors = true; }); }, LoggerFactory); @@ -141,7 +104,6 @@ public async Task ThrowsWhenUsedInOnDisconnectedAsync() { builder.AddSignalR(o => { - o.MaximumParallelInvocationsPerClient = 2; o.EnableDetailedErrors = true; }); }, LoggerFactory); @@ -180,7 +142,7 @@ public async Task CanUseClientResultsWithIHubContext() await client.Connected.OrThrowIfOtherFails(connectionHandlerTask).DefaultTimeout(); var context = serviceProvider.GetRequiredService>(); - var resultTask = context.Clients.Client(client.Connection.ConnectionId).InvokeAsync("GetClientResult", 1); + var resultTask = context.Clients.Client(client.Connection.ConnectionId).InvokeAsync("GetClientResult", 1, cancellationToken: default); var message = await client.ReadAsync().DefaultTimeout(); var invocation = Assert.IsType(message); @@ -235,14 +197,10 @@ public async Task CanReturnClientResultToTypedHubTwoWays() { using (StartVerifiableLog()) { - var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(builder => - { - // Waiting for a client result blocks the hub dispatcher pipeline, need to allow multiple invocations - builder.AddSignalR(o => o.MaximumParallelInvocationsPerClient = 2); - }, LoggerFactory); + var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(builder => { }, LoggerFactory); var connectionHandler = serviceProvider.GetService>(); - using var client = new TestClient(invocationBinder: new GetClientResultThreeWaysInvocationBinder()); + using var client = new TestClient(invocationBinder: new GetClientResultTwoWaysInvocationBinder()); var connectionHandlerTask = await client.ConnectAsync(connectionHandler).DefaultTimeout(); @@ -266,7 +224,223 @@ public async Task CanReturnClientResultToTypedHubTwoWays() } } - private class GetClientResultThreeWaysInvocationBinder : IInvocationBinder + [Fact] + public async Task ClientResultFromHubDoesNotBlockReceiveLoop() + { + using (StartVerifiableLog()) + { + var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(builder => + { + builder.AddSignalR(o => o.MaximumParallelInvocationsPerClient = 2); + }, LoggerFactory); + var connectionHandler = serviceProvider.GetService>(); + + using (var client = new TestClient()) + { + var connectionHandlerTask = await client.ConnectAsync(connectionHandler).DefaultTimeout(); + + // block 1 of the 2 parallel invocations + _ = await client.SendHubMessageAsync(new InvocationMessage("1", nameof(MethodHub.BlockingMethod), Array.Empty())).DefaultTimeout(); + + // make multiple invocations which would normally block the invocation processing + var invocationId = await client.SendHubMessageAsync(new InvocationMessage("2", nameof(MethodHub.GetClientResult), new object[] { 5 })).DefaultTimeout(); + var invocationId2 = await client.SendHubMessageAsync(new InvocationMessage("3", nameof(MethodHub.GetClientResult), new object[] { 5 })).DefaultTimeout(); + var invocationId3 = await client.SendHubMessageAsync(new InvocationMessage("4", nameof(MethodHub.GetClientResult), new object[] { 5 })).DefaultTimeout(); + + // Read all 3 invocation messages from the server, shows that the hub processing continued even though parallel invokes is 2 + // Hub asks client for a result, this is an invocation message with an ID + var invocationMessage = Assert.IsType(await client.ReadAsync().DefaultTimeout()); + var invocationMessage2 = Assert.IsType(await client.ReadAsync().DefaultTimeout()); + var invocationMessage3 = Assert.IsType(await client.ReadAsync().DefaultTimeout()); + + Assert.NotNull(invocationMessage.InvocationId); + Assert.NotNull(invocationMessage2.InvocationId); + Assert.NotNull(invocationMessage3.InvocationId); + var res = 4 + ((long)invocationMessage.Arguments[0]); + await client.SendHubMessageAsync(CompletionMessage.WithResult(invocationMessage.InvocationId, res)).DefaultTimeout(); + var completion = Assert.IsType(await client.ReadAsync().DefaultTimeout()); + Assert.Equal(9L, completion.Result); + Assert.Equal(invocationId, completion.InvocationId); + + res = 5 + ((long)invocationMessage2.Arguments[0]); + await client.SendHubMessageAsync(CompletionMessage.WithResult(invocationMessage2.InvocationId, res)).DefaultTimeout(); + completion = Assert.IsType(await client.ReadAsync().DefaultTimeout()); + Assert.Equal(10L, completion.Result); + Assert.Equal(invocationId2, completion.InvocationId); + + res = 6 + ((long)invocationMessage3.Arguments[0]); + await client.SendHubMessageAsync(CompletionMessage.WithResult(invocationMessage3.InvocationId, res)).DefaultTimeout(); + completion = Assert.IsType(await client.ReadAsync().DefaultTimeout()); + Assert.Equal(11L, completion.Result); + Assert.Equal(invocationId3, completion.InvocationId); + } + } + } + + [Fact] + public async Task ClientResultFromBackgroundThreadInHubMethodWorks() + { + using (StartVerifiableLog()) + { + var tcsService = new TcsService(); + var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(builder => + { + builder.AddSingleton(tcsService); + }, LoggerFactory); + var connectionHandler = serviceProvider.GetService>(); + + using (var client = new TestClient()) + { + var connectionHandlerTask = await client.ConnectAsync(connectionHandler).DefaultTimeout(); + + var completionMessage = await client.InvokeAsync(nameof(MethodHub.BackgroundClientResult)).DefaultTimeout(); + + tcsService.StartedMethod.SetResult(null); + + var task = await Task.WhenAny(tcsService.EndMethod.Task, client.ReadAsync()).DefaultTimeout(); + if (task == tcsService.EndMethod.Task) + { + await tcsService.EndMethod.Task; + } + // Hub asks client for a result, this is an invocation message with an ID + var invocationMessage = Assert.IsType(await (Task)task); + Assert.NotNull(invocationMessage.InvocationId); + var res = 4 + ((long)invocationMessage.Arguments[0]); + await client.SendHubMessageAsync(CompletionMessage.WithResult(invocationMessage.InvocationId, res)).DefaultTimeout(); + + Assert.Equal(5, await tcsService.EndMethod.Task.DefaultTimeout()); + + // Make sure we can still do a Hub invocation and that the semaphore state didn't get messed up + completionMessage = await client.InvokeAsync(nameof(MethodHub.ValueMethod)).DefaultTimeout(); + Assert.Equal(43L, completionMessage.Result); + } + } + } + + private class TestBinder : IInvocationBinder + { + public IReadOnlyList GetParameterTypes(string methodName) + { + return new Type[] { typeof(int) }; + } + + public Type GetReturnType(string invocationId) + { + return typeof(string); + } + + public Type GetStreamItemType(string streamId) + { + throw new NotImplementedException(); + } + } + + [Theory] + [InlineData("MessagePack")] + [InlineData("Json")] + public async Task CanCancelClientResultsWithIHubContextT(string protocol) + { + IHubProtocol hubProtocol; + if (string.Equals(protocol, "MessagePack")) + { + hubProtocol = new MessagePackHubProtocol(); + } + else if (string.Equals(protocol, "Json")) + { + hubProtocol = new JsonHubProtocol(); + } + else + { + throw new Exception($"Protocol {protocol} not handled by test."); + } + using (StartVerifiableLog()) + { + var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(null, LoggerFactory); + var connectionHandler = serviceProvider.GetService>(); + + using var client = new TestClient(hubProtocol, new TestBinder()); + var connectionId = client.Connection.ConnectionId; + + var connectionHandlerTask = await client.ConnectAsync(connectionHandler); + + // Wait for a connection, or for the endpoint to fail. + await client.Connected.OrThrowIfOtherFails(connectionHandlerTask).DefaultTimeout(); + + var context = serviceProvider.GetRequiredService>(); + + var cts = new CancellationTokenSource(); + var resultTask = context.Clients.Client(connectionId).GetClientResultWithCancellation(1, cts.Token); + + var message = await client.ReadAsync().DefaultTimeout(); + var invocation = Assert.IsType(message); + + Assert.Single(invocation.Arguments); + Assert.Equal(1, invocation.Arguments[0]); + Assert.Equal("GetClientResultWithCancellation", invocation.Target); + + cts.Cancel(); + + var ex = await Assert.ThrowsAsync(() => resultTask).DefaultTimeout(); + Assert.Equal("Invocation canceled by the server.", ex.Message); + + // Sending result after the server is no longer expecting one results in a log and no-ops + await client.SendHubMessageAsync(CompletionMessage.WithResult(invocation.InvocationId, 2)).DefaultTimeout(); + + // Send another message from the client and get a result back to make sure the connection is still active. + // Regression test for when sending a client result after it was canceled would close the connection + var completion = await client.InvokeAsync(nameof(HubT.Echo), "test").DefaultTimeout(); + Assert.Equal("test", completion.Result); + + Assert.Contains(TestSink.Writes, c => c.EventId.Name == "UnexpectedCompletion"); + } + } + + [Fact] + public async Task CanCancelClientResultsWithIHubContext() + { + using (StartVerifiableLog()) + { + var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(null, LoggerFactory); + var connectionHandler = serviceProvider.GetService>(); + + using var client = new TestClient(); + var connectionId = client.Connection.ConnectionId; + + var connectionHandlerTask = await client.ConnectAsync(connectionHandler); + + // Wait for a connection, or for the endpoint to fail. + await client.Connected.OrThrowIfOtherFails(connectionHandlerTask).DefaultTimeout(); + + var context = serviceProvider.GetRequiredService>(); + + var cts = new CancellationTokenSource(); + var resultTask = context.Clients.Client(connectionId).InvokeAsync(nameof(MethodHub.GetClientResult), 1, cts.Token); + + var message = await client.ReadAsync().DefaultTimeout(); + var invocation = Assert.IsType(message); + + Assert.Single(invocation.Arguments); + Assert.Equal(1L, invocation.Arguments[0]); + Assert.Equal("GetClientResult", invocation.Target); + + cts.Cancel(); + + var ex = await Assert.ThrowsAsync(() => resultTask).DefaultTimeout(); + Assert.Equal("Invocation canceled by the server.", ex.Message); + + // Sending result after the server is no longer expecting one results in a log and no-ops + await client.SendHubMessageAsync(CompletionMessage.WithResult(invocation.InvocationId, 2)).DefaultTimeout(); + + // Send another message from the client and get a result back to make sure the connection is still active. + // Regression test for when sending a client result after it was canceled would close the connection + var completion = await client.InvokeAsync("Echo", "test").DefaultTimeout(); + Assert.Equal("test", completion.Result); + + Assert.Contains(TestSink.Writes, c => c.EventId.Name == "UnexpectedCompletion"); + } + } + + private class GetClientResultTwoWaysInvocationBinder : IInvocationBinder { public IReadOnlyList GetParameterTypes(string methodName) => new[] { typeof(int) }; public Type GetReturnType(string invocationId) => typeof(ClientResults); diff --git a/src/SignalR/server/SignalR/test/HubConnectionHandlerTests.cs b/src/SignalR/server/SignalR/test/HubConnectionHandlerTests.cs index 8b35d4647590..32af4cfe6246 100644 --- a/src/SignalR/server/SignalR/test/HubConnectionHandlerTests.cs +++ b/src/SignalR/server/SignalR/test/HubConnectionHandlerTests.cs @@ -2991,16 +2991,23 @@ public async Task HubMethodInvokeDoesNotCountTowardsClientTimeout() await client.SendHubMessageAsync(PingMessage.Instance); // Call long running hub method - var hubMethodTask = client.InvokeAsync(nameof(LongRunningHub.LongRunningMethod)); + var hubMethodTask1 = client.InvokeAsync(nameof(LongRunningHub.LongRunningMethod)); await tcsService.StartedMethod.Task.DefaultTimeout(); + // Wait for server to start reading again + await customDuplex.WrappedPipeReader.WaitForReadStart().DefaultTimeout(); + // Send another invocation to server, since we use Inline scheduling we know that once this call completes the server will have read and processed + // the message, it should be stuck waiting for the in-progress invoke now + _ = await client.SendInvocationAsync(nameof(LongRunningHub.LongRunningMethod)).DefaultTimeout(); + // Tick heartbeat while hub method is running to show that close isn't triggered client.TickHeartbeat(); // Unblock long running hub method tcsService.EndMethod.SetResult(null); - await hubMethodTask.DefaultTimeout(); + await hubMethodTask1.DefaultTimeout(); + await client.ReadAsync().DefaultTimeout(); // There is a small window when the hub method finishes and the timer starts again // So we need to delay a little before ticking the heart beat. diff --git a/src/SignalR/server/Specification.Tests/src/HubLifetimeManagerTestBase.cs b/src/SignalR/server/Specification.Tests/src/HubLifetimeManagerTestBase.cs index 19b061d1c6bd..426a06ddeebc 100644 --- a/src/SignalR/server/Specification.Tests/src/HubLifetimeManagerTestBase.cs +++ b/src/SignalR/server/Specification.Tests/src/HubLifetimeManagerTestBase.cs @@ -186,7 +186,7 @@ public async Task CanProcessClientReturnResult() await manager.OnConnectedAsync(connection1).DefaultTimeout(); - var resultTask = manager.InvokeConnectionAsync(connection1.ConnectionId, "Result", new object[] { "test" }); + var resultTask = manager.InvokeConnectionAsync(connection1.ConnectionId, "Result", new object[] { "test" }, cancellationToken: default); var invocation = Assert.IsType(await client1.ReadAsync().DefaultTimeout()); Assert.NotNull(invocation.InvocationId); Assert.Equal("test", invocation.Arguments[0]); @@ -213,7 +213,7 @@ public async Task CanProcessClientReturnErrorResult() await manager.OnConnectedAsync(connection1).DefaultTimeout(); - var resultTask = manager.InvokeConnectionAsync(connection1.ConnectionId, "Result", new object[] { "test" }); + var resultTask = manager.InvokeConnectionAsync(connection1.ConnectionId, "Result", new object[] { "test" }, cancellationToken: default); var invocation = Assert.IsType(await client1.ReadAsync().DefaultTimeout()); Assert.NotNull(invocation.InvocationId); Assert.Equal("test", invocation.Arguments[0]); @@ -243,7 +243,7 @@ public async Task ExceptionWhenIncorrectClientCompletesClientResult() await manager.OnConnectedAsync(connection1).DefaultTimeout(); await manager.OnConnectedAsync(connection2).DefaultTimeout(); - var resultTask = manager.InvokeConnectionAsync(connection1.ConnectionId, "Result", new object[] { "test" }); + var resultTask = manager.InvokeConnectionAsync(connection1.ConnectionId, "Result", new object[] { "test" }, cancellationToken: default); var invocation = Assert.IsType(await client1.ReadAsync().DefaultTimeout()); Assert.NotNull(invocation.InvocationId); Assert.Equal("test", invocation.Arguments[0]); @@ -277,7 +277,7 @@ public async Task ConnectionIDNotPresentWhenInvokingClientResult() await manager1.OnConnectedAsync(connection1).DefaultTimeout(); // No client with this ID - await Assert.ThrowsAsync(() => manager1.InvokeConnectionAsync("none", "Result", new object[] { "test" })).DefaultTimeout(); + await Assert.ThrowsAsync(() => manager1.InvokeConnectionAsync("none", "Result", new object[] { "test" }, cancellationToken: default)).DefaultTimeout(); } } @@ -299,8 +299,8 @@ public async Task InvokesForMultipleClientsDoNotCollide() await manager1.OnConnectedAsync(connection1).DefaultTimeout(); await manager1.OnConnectedAsync(connection2).DefaultTimeout(); - var invoke1 = manager1.InvokeConnectionAsync(connection1.ConnectionId, "Result", new object[] { "test" }); - var invoke2 = manager1.InvokeConnectionAsync(connection2.ConnectionId, "Result", new object[] { "test" }); + var invoke1 = manager1.InvokeConnectionAsync(connection1.ConnectionId, "Result", new object[] { "test" }, cancellationToken: default); + var invoke2 = manager1.InvokeConnectionAsync(connection2.ConnectionId, "Result", new object[] { "test" }, cancellationToken: default); var invocation1 = Assert.IsType(await client1.ReadAsync().DefaultTimeout()); var invocation2 = Assert.IsType(await client2.ReadAsync().DefaultTimeout()); @@ -329,7 +329,7 @@ public async Task ClientDisconnectsWithoutCompletingClientResult() await manager1.OnConnectedAsync(connection1).DefaultTimeout(); - var invoke1 = manager1.InvokeConnectionAsync(connection1.ConnectionId, "Result", new object[] { "test" }); + var invoke1 = manager1.InvokeConnectionAsync(connection1.ConnectionId, "Result", new object[] { "test" }, cancellationToken: default); connection1.Abort(); await manager1.OnDisconnectedAsync(connection1).DefaultTimeout(); diff --git a/src/SignalR/server/Specification.Tests/src/ScaleoutHubLifetimeManagerTests.cs b/src/SignalR/server/Specification.Tests/src/ScaleoutHubLifetimeManagerTests.cs index 5fc1c9637c76..fcaad6c86345 100644 --- a/src/SignalR/server/Specification.Tests/src/ScaleoutHubLifetimeManagerTests.cs +++ b/src/SignalR/server/Specification.Tests/src/ScaleoutHubLifetimeManagerTests.cs @@ -482,7 +482,7 @@ public async Task CanProcessClientReturnResultAcrossServers() await manager1.OnConnectedAsync(connection1).DefaultTimeout(); // Server2 asks for a result from client1 on Server1 - var resultTask = manager2.InvokeConnectionAsync(connection1.ConnectionId, "Result", new object[] { "test" }); + var resultTask = manager2.InvokeConnectionAsync(connection1.ConnectionId, "Result", new object[] { "test" }, cancellationToken: default); var invocation = Assert.IsType(await client1.ReadAsync().DefaultTimeout()); Assert.NotNull(invocation.InvocationId); Assert.Equal("test", invocation.Arguments[0]); @@ -513,7 +513,7 @@ public async Task CanProcessClientReturnErrorResultAcrossServers() await manager1.OnConnectedAsync(connection1).DefaultTimeout(); // Server2 asks for a result from client1 on Server1 - var resultTask = manager2.InvokeConnectionAsync(connection1.ConnectionId, "Result", new object[] { "test" }); + var resultTask = manager2.InvokeConnectionAsync(connection1.ConnectionId, "Result", new object[] { "test" }, cancellationToken: default); var invocation = Assert.IsType(await client1.ReadAsync().DefaultTimeout()); Assert.NotNull(invocation.InvocationId); Assert.Equal("test", invocation.Arguments[0]); @@ -544,7 +544,7 @@ public async Task ConnectionIDNotPresentMultiServerWhenInvokingClientResult() await manager1.OnConnectedAsync(connection1).DefaultTimeout(); // No client on any backplanes with this ID - await Assert.ThrowsAsync(() => manager1.InvokeConnectionAsync("none", "Result", new object[] { "test" })).DefaultTimeout(); + await Assert.ThrowsAsync(() => manager1.InvokeConnectionAsync("none", "Result", new object[] { "test" }, cancellationToken: default)).DefaultTimeout(); } } @@ -565,7 +565,7 @@ public async Task ClientDisconnectsWithoutCompletingClientResultOnSecondServer() await manager2.OnConnectedAsync(connection1).DefaultTimeout(); - var invoke1 = manager1.InvokeConnectionAsync(connection1.ConnectionId, "Result", new object[] { "test" }); + var invoke1 = manager1.InvokeConnectionAsync(connection1.ConnectionId, "Result", new object[] { "test" }, cancellationToken: default); var invocation = Assert.IsType(await client1.ReadAsync().DefaultTimeout()); connection1.Abort(); @@ -597,10 +597,10 @@ public async Task InvocationsFromDifferentServersUseUniqueIDs() await manager1.OnConnectedAsync(connection1).DefaultTimeout(); await manager2.OnConnectedAsync(connection2).DefaultTimeout(); - var invoke1 = manager1.InvokeConnectionAsync(connection2.ConnectionId, "Result", new object[] { "test" }); + var invoke1 = manager1.InvokeConnectionAsync(connection2.ConnectionId, "Result", new object[] { "test" }, cancellationToken: default); var invocation2 = Assert.IsType(await client2.ReadAsync().DefaultTimeout()); - var invoke2 = manager2.InvokeConnectionAsync(connection1.ConnectionId, "Result", new object[] { "test" }); + var invoke2 = manager2.InvokeConnectionAsync(connection1.ConnectionId, "Result", new object[] { "test" }, cancellationToken: default); var invocation1 = Assert.IsType(await client1.ReadAsync().DefaultTimeout()); Assert.NotEqual(invocation1.InvocationId, invocation2.InvocationId); @@ -626,7 +626,7 @@ public async Task ConnectionDoesNotExist_FailsInvokeConnectionAsync() var manager1 = CreateNewHubLifetimeManager(backplane); var manager2 = CreateNewHubLifetimeManager(backplane); - var ex = await Assert.ThrowsAsync(() => manager1.InvokeConnectionAsync("1234", "Result", new object[] { "test" })).DefaultTimeout(); + var ex = await Assert.ThrowsAsync(() => manager1.InvokeConnectionAsync("1234", "Result", new object[] { "test" }, cancellationToken: default)).DefaultTimeout(); Assert.Equal("Connection '1234' does not exist.", ex.Message); } } diff --git a/src/SignalR/server/StackExchangeRedis/src/PublicAPI.Unshipped.txt b/src/SignalR/server/StackExchangeRedis/src/PublicAPI.Unshipped.txt index 3a2eac25c41d..09d6bc878431 100644 --- a/src/SignalR/server/StackExchangeRedis/src/PublicAPI.Unshipped.txt +++ b/src/SignalR/server/StackExchangeRedis/src/PublicAPI.Unshipped.txt @@ -1,4 +1,4 @@ #nullable enable -override Microsoft.AspNetCore.SignalR.StackExchangeRedis.RedisHubLifetimeManager.InvokeConnectionAsync(string! connectionId, string! methodName, object?[]! args, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task! +override Microsoft.AspNetCore.SignalR.StackExchangeRedis.RedisHubLifetimeManager.InvokeConnectionAsync(string! connectionId, string! methodName, object?[]! args, System.Threading.CancellationToken cancellationToken) -> System.Threading.Tasks.Task! override Microsoft.AspNetCore.SignalR.StackExchangeRedis.RedisHubLifetimeManager.SetConnectionResultAsync(string! connectionId, Microsoft.AspNetCore.SignalR.Protocol.CompletionMessage! result) -> System.Threading.Tasks.Task! override Microsoft.AspNetCore.SignalR.StackExchangeRedis.RedisHubLifetimeManager.TryGetReturnType(string! invocationId, out System.Type? type) -> bool diff --git a/src/SignalR/server/StackExchangeRedis/src/RedisHubLifetimeManager.cs b/src/SignalR/server/StackExchangeRedis/src/RedisHubLifetimeManager.cs index eb3bbab0cfc3..1d8397272ab4 100644 --- a/src/SignalR/server/StackExchangeRedis/src/RedisHubLifetimeManager.cs +++ b/src/SignalR/server/StackExchangeRedis/src/RedisHubLifetimeManager.cs @@ -406,7 +406,7 @@ public void Dispose() } /// - public override async Task InvokeConnectionAsync(string connectionId, string methodName, object?[] args, CancellationToken cancellationToken = default) + public override async Task InvokeConnectionAsync(string connectionId, string methodName, object?[] args, CancellationToken cancellationToken) { // send thing if (connectionId == null) @@ -428,8 +428,8 @@ public override async Task InvokeConnectionAsync(string connectionId, stri if (connection == null) { // TODO: Need to handle other server going away while waiting for connection result - var m = _protocol.WriteInvocation(methodName, args, invocationId, returnChannel: _channels.ReturnResults(_serverName)); - var received = await PublishAsync(_channels.Connection(connectionId), m); + var messageBytes = _protocol.WriteInvocation(methodName, args, invocationId, returnChannel: _channels.ReturnResults(_serverName)); + var received = await PublishAsync(_channels.Connection(connectionId), messageBytes); if (received < 1) { throw new IOException($"Connection '{connectionId}' does not exist.");