Skip to content

[SignalR] Avoid blocking common InvokeAsync usage #42796

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Aug 21, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 2 additions & 7 deletions src/SignalR/server/Core/src/HubConnectionContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -79,12 +79,8 @@ public HubConnectionContext(ConnectionContext connectionContext, HubConnectionCo
_systemClock = contextOptions.SystemClock ?? new SystemClock();
_lastSendTick = _systemClock.CurrentTicks;

// We'll be avoiding using the semaphore when the limit is set to 1, so no need to allocate it
var maxInvokeLimit = contextOptions.MaximumParallelInvocations;
if (maxInvokeLimit != 1)
{
ActiveInvocationLimit = new SemaphoreSlim(maxInvokeLimit, maxInvokeLimit);
}
ActiveInvocationLimit = new ChannelBasedSemaphore(maxInvokeLimit);
}

internal StreamTracker StreamTracker
Expand All @@ -102,11 +98,10 @@ internal StreamTracker StreamTracker
}

internal HubCallerContext HubCallerContext { get; }
internal HubCallerClients HubCallerClients { get; set; } = null!;

internal Exception? CloseException { get; private set; }

internal SemaphoreSlim? ActiveInvocationLimit { get; }
internal ChannelBasedSemaphore ActiveInvocationLimit { get; }

/// <summary>
/// Gets a <see cref="CancellationToken"/> that notifies when the connection is aborted.
Expand Down
3 changes: 3 additions & 0 deletions src/SignalR/server/Core/src/HubConnectionHandler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,9 @@ private async Task HubOnDisconnectedAsync(HubConnectionContext connection, Excep
// Ensure the connection is aborted before firing disconnect
await connection.AbortAsync();

// If a client result is requested in OnDisconnectedAsync we want to avoid the SemaphoreFullException and get the better connection disconnected IOException
_ = connection.ActiveInvocationLimit.TryAcquire();

try
{
await _dispatcher.OnDisconnectedAsync(connection, exception);
Expand Down
77 changes: 77 additions & 0 deletions src/SignalR/server/Core/src/Internal/ChannelBasedSemaphore.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.Diagnostics;
using System.Threading.Channels;

namespace Microsoft.AspNetCore.SignalR.Internal;

// Use a Channel instead of a SemaphoreSlim so that we can potentially save task allocations (ValueTask!)
// Additionally initial perf results show faster RPS when using Channel instead of SemaphoreSlim
internal sealed class ChannelBasedSemaphore
{
private readonly Channel<int> _channel;

public ChannelBasedSemaphore(int maxCapacity)
{
_channel = Channel.CreateBounded<int>(maxCapacity);
for (var i = 0; i < maxCapacity; i++)
{
_channel.Writer.TryWrite(1);
}
}

public bool TryAcquire()
{
return _channel.Reader.TryRead(out _);
}

// The int result isn't important, only reason it's exposed is because ValueTask<T> doesn't implement ValueTask so we can't cast like we could with Task<T> to Task
public ValueTask<int> WaitAsync(CancellationToken cancellationToken = default)
{
return _channel.Reader.ReadAsync(cancellationToken);
}

public void Release()
{
if (!_channel.Writer.TryWrite(1))
{
throw new SemaphoreFullException();
}
}

public ValueTask RunAsync<TState>(Func<TState, Task<bool>> callback, TState state)
{
if (TryAcquire())
{
_ = RunTask(callback, state);
return ValueTask.CompletedTask;
}

return RunSlowAsync(callback, state);
}

private async ValueTask RunSlowAsync<TState>(Func<TState, Task<bool>> callback, TState state)
{
_ = await WaitAsync();
_ = RunTask(callback, state);
}

private async Task RunTask<TState>(Func<TState, Task<bool>> callback, TState state)
{
try
{
var shouldRelease = await callback(state);
if (shouldRelease)
{
Release();
}
}
catch
{
// DefaultHubDispatcher catches and handles exceptions
// It does write to the connection in exception cases which also can't throw because we catch and log in HubConnectionContext
Debug.Assert(false);
}
}
}
30 changes: 17 additions & 13 deletions src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs
Original file line number Diff line number Diff line change
Expand Up @@ -73,13 +73,13 @@ public DefaultHubDispatcher(IServiceScopeFactory serviceScopeFactory, IHubContex
public override async Task OnConnectedAsync(HubConnectionContext connection)
{
await using var scope = _serviceScopeFactory.CreateAsyncScope();
connection.HubCallerClients = new HubCallerClients(_hubContext.Clients, connection.ConnectionId, connection.ActiveInvocationLimit is not null);

var hubActivator = scope.ServiceProvider.GetRequiredService<IHubActivator<THub>>();
var hub = hubActivator.Create();
try
{
InitializeHub(hub, connection);
// OnConnectedAsync won't work with client results (ISingleClientProxy.InvokeAsync)
InitializeHub(hub, connection, invokeAllowed: false);

if (_onConnectedMiddleware != null)
{
Expand All @@ -90,9 +90,6 @@ public override async Task OnConnectedAsync(HubConnectionContext connection)
{
await hub.OnConnectedAsync();
}

// OnConnectedAsync is finished, allow hub methods to use client results (ISingleClientProxy.InvokeAsync)
connection.HubCallerClients.InvokeAllowed = true;
}
finally
{
Expand Down Expand Up @@ -256,13 +253,13 @@ private Task ProcessInvocation(HubConnectionContext connection,
else
{
bool isStreamCall = descriptor.StreamingParameters != null;
if (connection.ActiveInvocationLimit != null && !isStreamCall && !isStreamResponse)
if (!isStreamCall && !isStreamResponse)
{
return connection.ActiveInvocationLimit.RunAsync(static state =>
{
var (dispatcher, descriptor, connection, invocationMessage) = state;
return dispatcher.Invoke(descriptor, connection, invocationMessage, isStreamResponse: false, isStreamCall: false);
}, (this, descriptor, connection, hubMethodInvocationMessage));
}, (this, descriptor, connection, hubMethodInvocationMessage)).AsTask();
}
else
{
Expand All @@ -271,11 +268,12 @@ private Task ProcessInvocation(HubConnectionContext connection,
}
}

private async Task Invoke(HubMethodDescriptor descriptor, HubConnectionContext connection,
private async Task<bool> Invoke(HubMethodDescriptor descriptor, HubConnectionContext connection,
HubMethodInvocationMessage hubMethodInvocationMessage, bool isStreamResponse, bool isStreamCall)
{
var methodExecutor = descriptor.MethodExecutor;

var wasSemaphoreReleased = false;
var disposeScope = true;
var scope = _serviceScopeFactory.CreateAsyncScope();
IHubActivator<THub>? hubActivator = null;
Expand All @@ -290,12 +288,12 @@ private async Task Invoke(HubMethodDescriptor descriptor, HubConnectionContext c
Log.HubMethodNotAuthorized(_logger, hubMethodInvocationMessage.Target);
await SendInvocationError(hubMethodInvocationMessage.InvocationId, connection,
$"Failed to invoke '{hubMethodInvocationMessage.Target}' because user is unauthorized");
return;
return true;
}

if (!await ValidateInvocationMode(descriptor, isStreamResponse, hubMethodInvocationMessage, connection))
{
return;
return true;
}

try
Expand All @@ -308,7 +306,7 @@ await SendInvocationError(hubMethodInvocationMessage.InvocationId, connection,
Log.InvalidHubParameters(_logger, hubMethodInvocationMessage.Target, ex);
await SendInvocationError(hubMethodInvocationMessage.InvocationId, connection,
ErrorMessageHelper.BuildErrorMessage($"An unexpected error occurred invoking '{hubMethodInvocationMessage.Target}' on the server.", ex, _enableDetailedErrors));
return;
return true;
}

InitializeHub(hub, connection);
Expand Down Expand Up @@ -404,9 +402,15 @@ await SendInvocationError(hubMethodInvocationMessage.InvocationId, connection,
{
if (disposeScope)
{
if (hub?.Clients is HubCallerClients hubCallerClients)
{
wasSemaphoreReleased = Interlocked.CompareExchange(ref hubCallerClients.ShouldReleaseSemaphore, 0, 1) == 0;
}
await CleanupInvocation(connection, hubMethodInvocationMessage, hubActivator, hub, scope);
}
}

return !wasSemaphoreReleased;
}

private static ValueTask CleanupInvocation(HubConnectionContext connection, HubMethodInvocationMessage hubMessage, IHubActivator<THub>? hubActivator,
Expand Down Expand Up @@ -553,9 +557,9 @@ private static async Task SendInvocationError(string? invocationId,
await connection.WriteAsync(CompletionMessage.WithError(invocationId, errorMessage));
}

private void InitializeHub(THub hub, HubConnectionContext connection)
private void InitializeHub(THub hub, HubConnectionContext connection, bool invokeAllowed = true)
{
hub.Clients = connection.HubCallerClients;
hub.Clients = new HubCallerClients(_hubContext.Clients, connection.ConnectionId, connection.ActiveInvocationLimit) { InvokeAllowed = invokeAllowed };
hub.Context = connection.HubCallerContext;
hub.Groups = _hubContext.Groups;
}
Expand Down
51 changes: 27 additions & 24 deletions src/SignalR/server/Core/src/Internal/HubCallerClients.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,40 +7,36 @@ internal sealed class HubCallerClients : IHubCallerClients
{
private readonly string _connectionId;
private readonly IHubClients _hubClients;
private readonly string[] _currentConnectionId;
private readonly bool _parallelEnabled;
internal readonly ChannelBasedSemaphore _parallelInvokes;

internal int ShouldReleaseSemaphore = 1;

// Client results don't work in OnConnectedAsync
// This property is set by the hub dispatcher when those methods are being called
// so we can prevent users from making blocking client calls by returning a custom ISingleClientProxy instance
internal bool InvokeAllowed { get; set; }

public HubCallerClients(IHubClients hubClients, string connectionId, bool parallelEnabled)
public HubCallerClients(IHubClients hubClients, string connectionId, ChannelBasedSemaphore parallelInvokes)
{
_connectionId = connectionId;
_hubClients = hubClients;
_currentConnectionId = new[] { _connectionId };
_parallelEnabled = parallelEnabled;
_parallelInvokes = parallelInvokes;
}

IClientProxy IHubCallerClients<IClientProxy>.Caller => Caller;
public ISingleClientProxy Caller
{
get
{
if (!_parallelEnabled)
{
return new NotParallelSingleClientProxy(_hubClients.Client(_connectionId));
}
if (!InvokeAllowed)
{
return new NoInvokeSingleClientProxy(_hubClients.Client(_connectionId));
}
return _hubClients.Client(_connectionId);
return new SingleClientProxy(_hubClients.Client(_connectionId), this);
}
}

public IClientProxy Others => _hubClients.AllExcept(_currentConnectionId);
public IClientProxy Others => _hubClients.AllExcept(new[] { _connectionId });

public IClientProxy All => _hubClients.All;

Expand All @@ -52,15 +48,11 @@ public IClientProxy AllExcept(IReadOnlyList<string> excludedConnectionIds)
IClientProxy IHubClients<IClientProxy>.Client(string connectionId) => Client(connectionId);
public ISingleClientProxy Client(string connectionId)
{
if (!_parallelEnabled)
{
return new NotParallelSingleClientProxy(_hubClients.Client(connectionId));
}
if (!InvokeAllowed)
{
return new NoInvokeSingleClientProxy(_hubClients.Client(_connectionId));
}
return _hubClients.Client(connectionId);
return new SingleClientProxy(_hubClients.Client(connectionId), this);
}

public IClientProxy Group(string groupName)
Expand All @@ -75,7 +67,7 @@ public IClientProxy Groups(IReadOnlyList<string> groupNames)

public IClientProxy OthersInGroup(string groupName)
{
return _hubClients.GroupExcept(groupName, _currentConnectionId);
return _hubClients.GroupExcept(groupName, new[] { _connectionId });
}

public IClientProxy GroupExcept(string groupName, IReadOnlyList<string> excludedConnectionIds)
Expand All @@ -98,18 +90,18 @@ public IClientProxy Users(IReadOnlyList<string> userIds)
return _hubClients.Users(userIds);
}

private sealed class NotParallelSingleClientProxy : ISingleClientProxy
private sealed class NoInvokeSingleClientProxy : ISingleClientProxy
{
private readonly ISingleClientProxy _proxy;

public NotParallelSingleClientProxy(ISingleClientProxy hubClients)
public NoInvokeSingleClientProxy(ISingleClientProxy hubClients)
{
_proxy = hubClients;
}

public Task<T> InvokeCoreAsync<T>(string method, object?[] args, CancellationToken cancellationToken = default)
{
throw new InvalidOperationException("Client results inside a Hub method requires HubOptions.MaximumParallelInvocationsPerClient to be greater than 1.");
throw new InvalidOperationException("Client results inside OnConnectedAsync Hub methods are not allowed.");
}

public Task SendCoreAsync(string method, object?[] args, CancellationToken cancellationToken = default)
Expand All @@ -118,18 +110,29 @@ public Task SendCoreAsync(string method, object?[] args, CancellationToken cance
}
}

private sealed class NoInvokeSingleClientProxy : ISingleClientProxy
private sealed class SingleClientProxy : ISingleClientProxy
{
private readonly ISingleClientProxy _proxy;
private readonly HubCallerClients _hubCallerClients;

public NoInvokeSingleClientProxy(ISingleClientProxy hubClients)
public SingleClientProxy(ISingleClientProxy hubClients, HubCallerClients hubCallerClients)
{
_proxy = hubClients;
_hubCallerClients = hubCallerClients;
}

public Task<T> InvokeCoreAsync<T>(string method, object?[] args, CancellationToken cancellationToken = default)
public async Task<T> InvokeCoreAsync<T>(string method, object?[] args, CancellationToken cancellationToken = default)
{
throw new InvalidOperationException("Client results inside OnConnectedAsync Hub methods are not allowed.");
// Releases the Channel that is blocking pending invokes, which in turn can block the receive loop.
// Because we are waiting for a result from the client we need to let the receive loop run otherwise we'll be blocked forever
var value = Interlocked.CompareExchange(ref _hubCallerClients.ShouldReleaseSemaphore, 0, 1);
// Only release once, and we set ShouldReleaseSemaphore to 0 so the DefaultHubDispatcher knows not to call Release again
if (value == 1)
{
_hubCallerClients._parallelInvokes.Release();
}
var result = await _proxy.InvokeCoreAsync<T>(method, args, cancellationToken);
return result;
}

public Task SendCoreAsync(string method, object?[] args, CancellationToken cancellationToken = default)
Expand Down
36 changes: 0 additions & 36 deletions src/SignalR/server/Core/src/Internal/SemaphoreSlimExtensions.cs

This file was deleted.

Loading