Skip to content

Commit 5ca9116

Browse files
authored
Call HttpContext.Abort() when WebSocket.Abort() is called. (#48892)
* Call HttpContext.Abort() when WebSocket.Abort() is called.
1 parent 477b22a commit 5ca9116

File tree

3 files changed

+248
-44
lines changed

3 files changed

+248
-44
lines changed
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
4+
using System.Net.WebSockets;
5+
using Microsoft.AspNetCore.Http;
6+
7+
namespace Microsoft.AspNetCore.WebSockets;
8+
9+
/// <summary>
10+
/// Used in ASP.NET Core to wrap a WebSocket with its associated HttpContext so that when the WebSocket is aborted
11+
/// the underlying HttpContext is aborted. All other methods are delegated to the underlying WebSocket.
12+
/// </summary>
13+
internal sealed class ServerWebSocket : WebSocket
14+
{
15+
private readonly WebSocket _wrappedSocket;
16+
private readonly HttpContext _context;
17+
18+
internal ServerWebSocket(WebSocket wrappedSocket, HttpContext context)
19+
{
20+
ArgumentNullException.ThrowIfNull(wrappedSocket);
21+
ArgumentNullException.ThrowIfNull(context);
22+
23+
_wrappedSocket = wrappedSocket;
24+
_context = context;
25+
}
26+
27+
public override WebSocketCloseStatus? CloseStatus => _wrappedSocket.CloseStatus;
28+
29+
public override string? CloseStatusDescription => _wrappedSocket.CloseStatusDescription;
30+
31+
public override WebSocketState State => _wrappedSocket.State;
32+
33+
public override string? SubProtocol => _wrappedSocket.SubProtocol;
34+
35+
public override void Abort()
36+
{
37+
_wrappedSocket.Abort();
38+
_context.Abort();
39+
}
40+
41+
public override Task CloseAsync(WebSocketCloseStatus closeStatus, string? statusDescription, CancellationToken cancellationToken)
42+
{
43+
return _wrappedSocket.CloseAsync(closeStatus, statusDescription, cancellationToken);
44+
}
45+
46+
public override Task CloseOutputAsync(WebSocketCloseStatus closeStatus, string? statusDescription, CancellationToken cancellationToken)
47+
{
48+
return _wrappedSocket.CloseOutputAsync(closeStatus, statusDescription, cancellationToken);
49+
}
50+
51+
public override void Dispose()
52+
{
53+
_wrappedSocket.Dispose();
54+
}
55+
56+
public override Task<WebSocketReceiveResult> ReceiveAsync(ArraySegment<byte> buffer, CancellationToken cancellationToken)
57+
{
58+
return _wrappedSocket.ReceiveAsync(buffer, cancellationToken);
59+
}
60+
61+
public override ValueTask<ValueWebSocketReceiveResult> ReceiveAsync(Memory<byte> buffer, CancellationToken cancellationToken)
62+
{
63+
return _wrappedSocket.ReceiveAsync(buffer, cancellationToken);
64+
}
65+
66+
public override ValueTask SendAsync(ReadOnlyMemory<byte> buffer, WebSocketMessageType messageType, bool endOfMessage, CancellationToken cancellationToken)
67+
{
68+
return _wrappedSocket.SendAsync(buffer, messageType, endOfMessage, cancellationToken);
69+
}
70+
71+
public override ValueTask SendAsync(ReadOnlyMemory<byte> buffer, WebSocketMessageType messageType, WebSocketMessageFlags messageFlags, CancellationToken cancellationToken)
72+
{
73+
return _wrappedSocket.SendAsync(buffer, messageType, messageFlags, cancellationToken);
74+
}
75+
76+
public override Task SendAsync(ArraySegment<byte> buffer, WebSocketMessageType messageType, bool endOfMessage, CancellationToken cancellationToken)
77+
{
78+
return _wrappedSocket.SendAsync(buffer, messageType, endOfMessage, cancellationToken);
79+
}
80+
}

src/Middleware/WebSockets/src/WebSocketMiddleware.cs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -208,13 +208,15 @@ public async Task<WebSocket> AcceptAsync(WebSocketAcceptContext acceptContext)
208208
// Disable request timeout, if there is one, after the websocket has been accepted
209209
_context.Features.Get<IHttpRequestTimeoutFeature>()?.DisableTimeout();
210210

211-
return WebSocket.CreateFromStream(opaqueTransport, new WebSocketCreationOptions()
211+
var wrappedSocket = WebSocket.CreateFromStream(opaqueTransport, new WebSocketCreationOptions()
212212
{
213213
IsServer = true,
214214
KeepAliveInterval = keepAliveInterval,
215215
SubProtocol = subProtocol,
216216
DangerousDeflateOptions = deflateOptions
217217
});
218+
219+
return new ServerWebSocket(wrappedSocket, _context);
218220
}
219221

220222
public static bool CheckSupportedWebSocketRequest(string method, IHeaderDictionary requestHeaders)

0 commit comments

Comments
 (0)