diff --git a/src/SignalR/server/Core/src/HubConnectionContext.cs b/src/SignalR/server/Core/src/HubConnectionContext.cs index 04e211d74f86..5ac7769617ce 100644 --- a/src/SignalR/server/Core/src/HubConnectionContext.cs +++ b/src/SignalR/server/Core/src/HubConnectionContext.cs @@ -79,12 +79,8 @@ public HubConnectionContext(ConnectionContext connectionContext, HubConnectionCo _systemClock = contextOptions.SystemClock ?? new SystemClock(); _lastSendTick = _systemClock.CurrentTicks; - // We'll be avoiding using the semaphore when the limit is set to 1, so no need to allocate it var maxInvokeLimit = contextOptions.MaximumParallelInvocations; - if (maxInvokeLimit != 1) - { - ActiveInvocationLimit = new SemaphoreSlim(maxInvokeLimit, maxInvokeLimit); - } + ActiveInvocationLimit = new ChannelBasedSemaphore(maxInvokeLimit); } internal StreamTracker StreamTracker @@ -102,11 +98,10 @@ internal StreamTracker StreamTracker } internal HubCallerContext HubCallerContext { get; } - internal HubCallerClients HubCallerClients { get; set; } = null!; internal Exception? CloseException { get; private set; } - internal SemaphoreSlim? ActiveInvocationLimit { get; } + internal ChannelBasedSemaphore ActiveInvocationLimit { get; } /// /// Gets a that notifies when the connection is aborted. diff --git a/src/SignalR/server/Core/src/HubConnectionHandler.cs b/src/SignalR/server/Core/src/HubConnectionHandler.cs index 3bb5566e6a7a..3269d6f7afcd 100644 --- a/src/SignalR/server/Core/src/HubConnectionHandler.cs +++ b/src/SignalR/server/Core/src/HubConnectionHandler.cs @@ -200,6 +200,9 @@ private async Task HubOnDisconnectedAsync(HubConnectionContext connection, Excep // Ensure the connection is aborted before firing disconnect await connection.AbortAsync(); + // If a client result is requested in OnDisconnectedAsync we want to avoid the SemaphoreFullException and get the better connection disconnected IOException + _ = connection.ActiveInvocationLimit.TryAcquire(); + try { await _dispatcher.OnDisconnectedAsync(connection, exception); diff --git a/src/SignalR/server/Core/src/Internal/ChannelBasedSemaphore.cs b/src/SignalR/server/Core/src/Internal/ChannelBasedSemaphore.cs new file mode 100644 index 000000000000..8b6bbbe0ec6f --- /dev/null +++ b/src/SignalR/server/Core/src/Internal/ChannelBasedSemaphore.cs @@ -0,0 +1,77 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Diagnostics; +using System.Threading.Channels; + +namespace Microsoft.AspNetCore.SignalR.Internal; + +// Use a Channel instead of a SemaphoreSlim so that we can potentially save task allocations (ValueTask!) +// Additionally initial perf results show faster RPS when using Channel instead of SemaphoreSlim +internal sealed class ChannelBasedSemaphore +{ + private readonly Channel _channel; + + public ChannelBasedSemaphore(int maxCapacity) + { + _channel = Channel.CreateBounded(maxCapacity); + for (var i = 0; i < maxCapacity; i++) + { + _channel.Writer.TryWrite(1); + } + } + + public bool TryAcquire() + { + return _channel.Reader.TryRead(out _); + } + + // The int result isn't important, only reason it's exposed is because ValueTask doesn't implement ValueTask so we can't cast like we could with Task to Task + public ValueTask WaitAsync(CancellationToken cancellationToken = default) + { + return _channel.Reader.ReadAsync(cancellationToken); + } + + public void Release() + { + if (!_channel.Writer.TryWrite(1)) + { + throw new SemaphoreFullException(); + } + } + + public ValueTask RunAsync(Func> callback, TState state) + { + if (TryAcquire()) + { + _ = RunTask(callback, state); + return ValueTask.CompletedTask; + } + + return RunSlowAsync(callback, state); + } + + private async ValueTask RunSlowAsync(Func> callback, TState state) + { + _ = await WaitAsync(); + _ = RunTask(callback, state); + } + + private async Task RunTask(Func> callback, TState state) + { + try + { + var shouldRelease = await callback(state); + if (shouldRelease) + { + Release(); + } + } + catch + { + // DefaultHubDispatcher catches and handles exceptions + // It does write to the connection in exception cases which also can't throw because we catch and log in HubConnectionContext + Debug.Assert(false); + } + } +} diff --git a/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs b/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs index 59b06dbf101a..30c06e594c7c 100644 --- a/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs +++ b/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs @@ -73,13 +73,13 @@ public DefaultHubDispatcher(IServiceScopeFactory serviceScopeFactory, IHubContex public override async Task OnConnectedAsync(HubConnectionContext connection) { await using var scope = _serviceScopeFactory.CreateAsyncScope(); - connection.HubCallerClients = new HubCallerClients(_hubContext.Clients, connection.ConnectionId, connection.ActiveInvocationLimit is not null); var hubActivator = scope.ServiceProvider.GetRequiredService>(); var hub = hubActivator.Create(); try { - InitializeHub(hub, connection); + // OnConnectedAsync won't work with client results (ISingleClientProxy.InvokeAsync) + InitializeHub(hub, connection, invokeAllowed: false); if (_onConnectedMiddleware != null) { @@ -90,9 +90,6 @@ public override async Task OnConnectedAsync(HubConnectionContext connection) { await hub.OnConnectedAsync(); } - - // OnConnectedAsync is finished, allow hub methods to use client results (ISingleClientProxy.InvokeAsync) - connection.HubCallerClients.InvokeAllowed = true; } finally { @@ -256,13 +253,13 @@ private Task ProcessInvocation(HubConnectionContext connection, else { bool isStreamCall = descriptor.StreamingParameters != null; - if (connection.ActiveInvocationLimit != null && !isStreamCall && !isStreamResponse) + if (!isStreamCall && !isStreamResponse) { return connection.ActiveInvocationLimit.RunAsync(static state => { var (dispatcher, descriptor, connection, invocationMessage) = state; return dispatcher.Invoke(descriptor, connection, invocationMessage, isStreamResponse: false, isStreamCall: false); - }, (this, descriptor, connection, hubMethodInvocationMessage)); + }, (this, descriptor, connection, hubMethodInvocationMessage)).AsTask(); } else { @@ -271,11 +268,12 @@ private Task ProcessInvocation(HubConnectionContext connection, } } - private async Task Invoke(HubMethodDescriptor descriptor, HubConnectionContext connection, + private async Task Invoke(HubMethodDescriptor descriptor, HubConnectionContext connection, HubMethodInvocationMessage hubMethodInvocationMessage, bool isStreamResponse, bool isStreamCall) { var methodExecutor = descriptor.MethodExecutor; + var wasSemaphoreReleased = false; var disposeScope = true; var scope = _serviceScopeFactory.CreateAsyncScope(); IHubActivator? hubActivator = null; @@ -290,12 +288,12 @@ private async Task Invoke(HubMethodDescriptor descriptor, HubConnectionContext c Log.HubMethodNotAuthorized(_logger, hubMethodInvocationMessage.Target); await SendInvocationError(hubMethodInvocationMessage.InvocationId, connection, $"Failed to invoke '{hubMethodInvocationMessage.Target}' because user is unauthorized"); - return; + return true; } if (!await ValidateInvocationMode(descriptor, isStreamResponse, hubMethodInvocationMessage, connection)) { - return; + return true; } try @@ -308,7 +306,7 @@ await SendInvocationError(hubMethodInvocationMessage.InvocationId, connection, Log.InvalidHubParameters(_logger, hubMethodInvocationMessage.Target, ex); await SendInvocationError(hubMethodInvocationMessage.InvocationId, connection, ErrorMessageHelper.BuildErrorMessage($"An unexpected error occurred invoking '{hubMethodInvocationMessage.Target}' on the server.", ex, _enableDetailedErrors)); - return; + return true; } InitializeHub(hub, connection); @@ -404,9 +402,15 @@ await SendInvocationError(hubMethodInvocationMessage.InvocationId, connection, { if (disposeScope) { + if (hub?.Clients is HubCallerClients hubCallerClients) + { + wasSemaphoreReleased = Interlocked.CompareExchange(ref hubCallerClients.ShouldReleaseSemaphore, 0, 1) == 0; + } await CleanupInvocation(connection, hubMethodInvocationMessage, hubActivator, hub, scope); } } + + return !wasSemaphoreReleased; } private static ValueTask CleanupInvocation(HubConnectionContext connection, HubMethodInvocationMessage hubMessage, IHubActivator? hubActivator, @@ -553,9 +557,9 @@ private static async Task SendInvocationError(string? invocationId, await connection.WriteAsync(CompletionMessage.WithError(invocationId, errorMessage)); } - private void InitializeHub(THub hub, HubConnectionContext connection) + private void InitializeHub(THub hub, HubConnectionContext connection, bool invokeAllowed = true) { - hub.Clients = connection.HubCallerClients; + hub.Clients = new HubCallerClients(_hubContext.Clients, connection.ConnectionId, connection.ActiveInvocationLimit) { InvokeAllowed = invokeAllowed }; hub.Context = connection.HubCallerContext; hub.Groups = _hubContext.Groups; } diff --git a/src/SignalR/server/Core/src/Internal/HubCallerClients.cs b/src/SignalR/server/Core/src/Internal/HubCallerClients.cs index e2a65ca7d1d9..8e6ec0fa0dc9 100644 --- a/src/SignalR/server/Core/src/Internal/HubCallerClients.cs +++ b/src/SignalR/server/Core/src/Internal/HubCallerClients.cs @@ -7,20 +7,20 @@ internal sealed class HubCallerClients : IHubCallerClients { private readonly string _connectionId; private readonly IHubClients _hubClients; - private readonly string[] _currentConnectionId; - private readonly bool _parallelEnabled; + internal readonly ChannelBasedSemaphore _parallelInvokes; + + internal int ShouldReleaseSemaphore = 1; // Client results don't work in OnConnectedAsync // This property is set by the hub dispatcher when those methods are being called // so we can prevent users from making blocking client calls by returning a custom ISingleClientProxy instance internal bool InvokeAllowed { get; set; } - public HubCallerClients(IHubClients hubClients, string connectionId, bool parallelEnabled) + public HubCallerClients(IHubClients hubClients, string connectionId, ChannelBasedSemaphore parallelInvokes) { _connectionId = connectionId; _hubClients = hubClients; - _currentConnectionId = new[] { _connectionId }; - _parallelEnabled = parallelEnabled; + _parallelInvokes = parallelInvokes; } IClientProxy IHubCallerClients.Caller => Caller; @@ -28,19 +28,15 @@ public ISingleClientProxy Caller { get { - if (!_parallelEnabled) - { - return new NotParallelSingleClientProxy(_hubClients.Client(_connectionId)); - } if (!InvokeAllowed) { return new NoInvokeSingleClientProxy(_hubClients.Client(_connectionId)); } - return _hubClients.Client(_connectionId); + return new SingleClientProxy(_hubClients.Client(_connectionId), this); } } - public IClientProxy Others => _hubClients.AllExcept(_currentConnectionId); + public IClientProxy Others => _hubClients.AllExcept(new[] { _connectionId }); public IClientProxy All => _hubClients.All; @@ -52,15 +48,11 @@ public IClientProxy AllExcept(IReadOnlyList excludedConnectionIds) IClientProxy IHubClients.Client(string connectionId) => Client(connectionId); public ISingleClientProxy Client(string connectionId) { - if (!_parallelEnabled) - { - return new NotParallelSingleClientProxy(_hubClients.Client(connectionId)); - } if (!InvokeAllowed) { return new NoInvokeSingleClientProxy(_hubClients.Client(_connectionId)); } - return _hubClients.Client(connectionId); + return new SingleClientProxy(_hubClients.Client(connectionId), this); } public IClientProxy Group(string groupName) @@ -75,7 +67,7 @@ public IClientProxy Groups(IReadOnlyList groupNames) public IClientProxy OthersInGroup(string groupName) { - return _hubClients.GroupExcept(groupName, _currentConnectionId); + return _hubClients.GroupExcept(groupName, new[] { _connectionId }); } public IClientProxy GroupExcept(string groupName, IReadOnlyList excludedConnectionIds) @@ -98,18 +90,18 @@ public IClientProxy Users(IReadOnlyList userIds) return _hubClients.Users(userIds); } - private sealed class NotParallelSingleClientProxy : ISingleClientProxy + private sealed class NoInvokeSingleClientProxy : ISingleClientProxy { private readonly ISingleClientProxy _proxy; - public NotParallelSingleClientProxy(ISingleClientProxy hubClients) + public NoInvokeSingleClientProxy(ISingleClientProxy hubClients) { _proxy = hubClients; } public Task InvokeCoreAsync(string method, object?[] args, CancellationToken cancellationToken = default) { - throw new InvalidOperationException("Client results inside a Hub method requires HubOptions.MaximumParallelInvocationsPerClient to be greater than 1."); + throw new InvalidOperationException("Client results inside OnConnectedAsync Hub methods are not allowed."); } public Task SendCoreAsync(string method, object?[] args, CancellationToken cancellationToken = default) @@ -118,18 +110,29 @@ public Task SendCoreAsync(string method, object?[] args, CancellationToken cance } } - private sealed class NoInvokeSingleClientProxy : ISingleClientProxy + private sealed class SingleClientProxy : ISingleClientProxy { private readonly ISingleClientProxy _proxy; + private readonly HubCallerClients _hubCallerClients; - public NoInvokeSingleClientProxy(ISingleClientProxy hubClients) + public SingleClientProxy(ISingleClientProxy hubClients, HubCallerClients hubCallerClients) { _proxy = hubClients; + _hubCallerClients = hubCallerClients; } - public Task InvokeCoreAsync(string method, object?[] args, CancellationToken cancellationToken = default) + public async Task InvokeCoreAsync(string method, object?[] args, CancellationToken cancellationToken = default) { - throw new InvalidOperationException("Client results inside OnConnectedAsync Hub methods are not allowed."); + // Releases the Channel that is blocking pending invokes, which in turn can block the receive loop. + // Because we are waiting for a result from the client we need to let the receive loop run otherwise we'll be blocked forever + var value = Interlocked.CompareExchange(ref _hubCallerClients.ShouldReleaseSemaphore, 0, 1); + // Only release once, and we set ShouldReleaseSemaphore to 0 so the DefaultHubDispatcher knows not to call Release again + if (value == 1) + { + _hubCallerClients._parallelInvokes.Release(); + } + var result = await _proxy.InvokeCoreAsync(method, args, cancellationToken); + return result; } public Task SendCoreAsync(string method, object?[] args, CancellationToken cancellationToken = default) diff --git a/src/SignalR/server/Core/src/Internal/SemaphoreSlimExtensions.cs b/src/SignalR/server/Core/src/Internal/SemaphoreSlimExtensions.cs deleted file mode 100644 index a238d09643e3..000000000000 --- a/src/SignalR/server/Core/src/Internal/SemaphoreSlimExtensions.cs +++ /dev/null @@ -1,36 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. - -namespace Microsoft.AspNetCore.SignalR.Internal; - -internal static class SemaphoreSlimExtensions -{ - public static Task RunAsync(this SemaphoreSlim semaphoreSlim, Func callback, TState state) - { - if (semaphoreSlim.Wait(0)) - { - _ = RunTask(callback, semaphoreSlim, state); - return Task.CompletedTask; - } - - return RunSlowAsync(semaphoreSlim, callback, state); - } - - private static async Task RunSlowAsync(this SemaphoreSlim semaphoreSlim, Func callback, TState state) - { - await semaphoreSlim.WaitAsync(); - return RunTask(callback, semaphoreSlim, state); - } - - static async Task RunTask(Func callback, SemaphoreSlim semaphoreSlim, TState state) - { - try - { - await callback(state); - } - finally - { - semaphoreSlim.Release(); - } - } -} diff --git a/src/SignalR/server/SignalR/test/HubConnectionHandlerTestUtils/Hubs.cs b/src/SignalR/server/SignalR/test/HubConnectionHandlerTestUtils/Hubs.cs index 0b659acb416e..dc4ad919292e 100644 --- a/src/SignalR/server/SignalR/test/HubConnectionHandlerTestUtils/Hubs.cs +++ b/src/SignalR/server/SignalR/test/HubConnectionHandlerTestUtils/Hubs.cs @@ -338,6 +338,24 @@ public async Task GetClientResult(int num) var sum = await Clients.Caller.InvokeAsync("Sum", num, cancellationToken: default); return sum; } + + public void BackgroundClientResult(TcsService tcsService) + { + var caller = Clients.Caller; + _ = Task.Run(async () => + { + try + { + await tcsService.StartedMethod.Task; + var result = await caller.InvokeAsync("GetResult", 1, CancellationToken.None); + tcsService.EndMethod.SetResult(result); + } + catch (Exception ex) + { + tcsService.EndMethod.SetException(ex); + } + }); + } } internal class SelfRef diff --git a/src/SignalR/server/SignalR/test/HubConnectionHandlerTests.ClientResult.cs b/src/SignalR/server/SignalR/test/HubConnectionHandlerTests.ClientResult.cs index 74b5dcec0b8f..1320c674d622 100644 --- a/src/SignalR/server/SignalR/test/HubConnectionHandlerTests.ClientResult.cs +++ b/src/SignalR/server/SignalR/test/HubConnectionHandlerTests.ClientResult.cs @@ -14,11 +14,7 @@ public async Task CanReturnClientResultToHub() { using (StartVerifiableLog()) { - var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(builder => - { - // Waiting for a client result blocks the hub dispatcher pipeline, need to allow multiple invocations - builder.AddSignalR(o => o.MaximumParallelInvocationsPerClient = 2); - }, LoggerFactory); + var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(builder => { }, LoggerFactory); var connectionHandler = serviceProvider.GetService>(); using (var client = new TestClient()) @@ -47,10 +43,8 @@ public async Task CanReturnClientResultErrorToHub() { var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(builder => { - // Waiting for a client result blocks the hub dispatcher pipeline, need to allow multiple invocations builder.AddSignalR(o => { - o.MaximumParallelInvocationsPerClient = 2; o.EnableDetailedErrors = true; }); }, LoggerFactory); @@ -74,36 +68,6 @@ public async Task CanReturnClientResultErrorToHub() } } - [Fact] - public async Task ThrowsWhenParallelHubInvokesNotEnabled() - { - using (StartVerifiableLog(write => write.EventId.Name == "FailedInvokingHubMethod")) - { - var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(builder => - { - builder.AddSignalR(o => - { - o.MaximumParallelInvocationsPerClient = 1; - o.EnableDetailedErrors = true; - }); - }, LoggerFactory); - var connectionHandler = serviceProvider.GetService>(); - - using (var client = new TestClient()) - { - var connectionHandlerTask = await client.ConnectAsync(connectionHandler).DefaultTimeout(); - - var invocationId = await client.SendHubMessageAsync(new InvocationMessage("1", nameof(MethodHub.GetClientResult), new object[] { 5 })).DefaultTimeout(); - - // Hub asks client for a result, this is an invocation message with an ID - var completionMessage = Assert.IsType(await client.ReadAsync().DefaultTimeout()); - Assert.Equal(invocationId, completionMessage.InvocationId); - Assert.Equal("An unexpected error occurred invoking 'GetClientResult' on the server. InvalidOperationException: Client results inside a Hub method requires HubOptions.MaximumParallelInvocationsPerClient to be greater than 1.", - completionMessage.Error); - } - } - } - [Fact] public async Task ThrowsWhenUsedInOnConnectedAsync() { @@ -113,7 +77,6 @@ public async Task ThrowsWhenUsedInOnConnectedAsync() { builder.AddSignalR(o => { - o.MaximumParallelInvocationsPerClient = 2; o.EnableDetailedErrors = true; }); }, LoggerFactory); @@ -141,7 +104,6 @@ public async Task ThrowsWhenUsedInOnDisconnectedAsync() { builder.AddSignalR(o => { - o.MaximumParallelInvocationsPerClient = 2; o.EnableDetailedErrors = true; }); }, LoggerFactory); @@ -235,11 +197,7 @@ public async Task CanReturnClientResultToTypedHubTwoWays() { using (StartVerifiableLog()) { - var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(builder => - { - // Waiting for a client result blocks the hub dispatcher pipeline, need to allow multiple invocations - builder.AddSignalR(o => o.MaximumParallelInvocationsPerClient = 2); - }, LoggerFactory); + var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(builder => { }, LoggerFactory); var connectionHandler = serviceProvider.GetService>(); using var client = new TestClient(invocationBinder: new GetClientResultTwoWaysInvocationBinder()); @@ -266,6 +224,99 @@ public async Task CanReturnClientResultToTypedHubTwoWays() } } + [Fact] + public async Task ClientResultFromHubDoesNotBlockReceiveLoop() + { + using (StartVerifiableLog()) + { + var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(builder => + { + builder.AddSignalR(o => o.MaximumParallelInvocationsPerClient = 2); + }, LoggerFactory); + var connectionHandler = serviceProvider.GetService>(); + + using (var client = new TestClient()) + { + var connectionHandlerTask = await client.ConnectAsync(connectionHandler).DefaultTimeout(); + + // block 1 of the 2 parallel invocations + _ = await client.SendHubMessageAsync(new InvocationMessage("1", nameof(MethodHub.BlockingMethod), Array.Empty())).DefaultTimeout(); + + // make multiple invocations which would normally block the invocation processing + var invocationId = await client.SendHubMessageAsync(new InvocationMessage("2", nameof(MethodHub.GetClientResult), new object[] { 5 })).DefaultTimeout(); + var invocationId2 = await client.SendHubMessageAsync(new InvocationMessage("3", nameof(MethodHub.GetClientResult), new object[] { 5 })).DefaultTimeout(); + var invocationId3 = await client.SendHubMessageAsync(new InvocationMessage("4", nameof(MethodHub.GetClientResult), new object[] { 5 })).DefaultTimeout(); + + // Read all 3 invocation messages from the server, shows that the hub processing continued even though parallel invokes is 2 + // Hub asks client for a result, this is an invocation message with an ID + var invocationMessage = Assert.IsType(await client.ReadAsync().DefaultTimeout()); + var invocationMessage2 = Assert.IsType(await client.ReadAsync().DefaultTimeout()); + var invocationMessage3 = Assert.IsType(await client.ReadAsync().DefaultTimeout()); + + Assert.NotNull(invocationMessage.InvocationId); + Assert.NotNull(invocationMessage2.InvocationId); + Assert.NotNull(invocationMessage3.InvocationId); + var res = 4 + ((long)invocationMessage.Arguments[0]); + await client.SendHubMessageAsync(CompletionMessage.WithResult(invocationMessage.InvocationId, res)).DefaultTimeout(); + var completion = Assert.IsType(await client.ReadAsync().DefaultTimeout()); + Assert.Equal(9L, completion.Result); + Assert.Equal(invocationId, completion.InvocationId); + + res = 5 + ((long)invocationMessage2.Arguments[0]); + await client.SendHubMessageAsync(CompletionMessage.WithResult(invocationMessage2.InvocationId, res)).DefaultTimeout(); + completion = Assert.IsType(await client.ReadAsync().DefaultTimeout()); + Assert.Equal(10L, completion.Result); + Assert.Equal(invocationId2, completion.InvocationId); + + res = 6 + ((long)invocationMessage3.Arguments[0]); + await client.SendHubMessageAsync(CompletionMessage.WithResult(invocationMessage3.InvocationId, res)).DefaultTimeout(); + completion = Assert.IsType(await client.ReadAsync().DefaultTimeout()); + Assert.Equal(11L, completion.Result); + Assert.Equal(invocationId3, completion.InvocationId); + } + } + } + + [Fact] + public async Task ClientResultFromBackgroundThreadInHubMethodWorks() + { + using (StartVerifiableLog()) + { + var tcsService = new TcsService(); + var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(builder => + { + builder.AddSingleton(tcsService); + }, LoggerFactory); + var connectionHandler = serviceProvider.GetService>(); + + using (var client = new TestClient()) + { + var connectionHandlerTask = await client.ConnectAsync(connectionHandler).DefaultTimeout(); + + var completionMessage = await client.InvokeAsync(nameof(MethodHub.BackgroundClientResult)).DefaultTimeout(); + + tcsService.StartedMethod.SetResult(null); + + var task = await Task.WhenAny(tcsService.EndMethod.Task, client.ReadAsync()).DefaultTimeout(); + if (task == tcsService.EndMethod.Task) + { + await tcsService.EndMethod.Task; + } + // Hub asks client for a result, this is an invocation message with an ID + var invocationMessage = Assert.IsType(await (Task)task); + Assert.NotNull(invocationMessage.InvocationId); + var res = 4 + ((long)invocationMessage.Arguments[0]); + await client.SendHubMessageAsync(CompletionMessage.WithResult(invocationMessage.InvocationId, res)).DefaultTimeout(); + + Assert.Equal(5, await tcsService.EndMethod.Task.DefaultTimeout()); + + // Make sure we can still do a Hub invocation and that the semaphore state didn't get messed up + completionMessage = await client.InvokeAsync(nameof(MethodHub.ValueMethod)).DefaultTimeout(); + Assert.Equal(43L, completionMessage.Result); + } + } + } + private class TestBinder : IInvocationBinder { public IReadOnlyList GetParameterTypes(string methodName) diff --git a/src/SignalR/server/SignalR/test/HubConnectionHandlerTests.cs b/src/SignalR/server/SignalR/test/HubConnectionHandlerTests.cs index 8b35d4647590..32af4cfe6246 100644 --- a/src/SignalR/server/SignalR/test/HubConnectionHandlerTests.cs +++ b/src/SignalR/server/SignalR/test/HubConnectionHandlerTests.cs @@ -2991,16 +2991,23 @@ public async Task HubMethodInvokeDoesNotCountTowardsClientTimeout() await client.SendHubMessageAsync(PingMessage.Instance); // Call long running hub method - var hubMethodTask = client.InvokeAsync(nameof(LongRunningHub.LongRunningMethod)); + var hubMethodTask1 = client.InvokeAsync(nameof(LongRunningHub.LongRunningMethod)); await tcsService.StartedMethod.Task.DefaultTimeout(); + // Wait for server to start reading again + await customDuplex.WrappedPipeReader.WaitForReadStart().DefaultTimeout(); + // Send another invocation to server, since we use Inline scheduling we know that once this call completes the server will have read and processed + // the message, it should be stuck waiting for the in-progress invoke now + _ = await client.SendInvocationAsync(nameof(LongRunningHub.LongRunningMethod)).DefaultTimeout(); + // Tick heartbeat while hub method is running to show that close isn't triggered client.TickHeartbeat(); // Unblock long running hub method tcsService.EndMethod.SetResult(null); - await hubMethodTask.DefaultTimeout(); + await hubMethodTask1.DefaultTimeout(); + await client.ReadAsync().DefaultTimeout(); // There is a small window when the hub method finishes and the timer starts again // So we need to delay a little before ticking the heart beat.