Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 7 additions & 29 deletions src/libraries/Common/tests/System/Net/Configuration.WebSockets.cs
Original file line number Diff line number Diff line change
Expand Up @@ -22,35 +22,13 @@ public static partial class WebSockets
public static readonly Uri RemoteEchoHeadersServer = new Uri("ws://" + Host + "/" + EchoHeadersHandler);
public static readonly Uri SecureRemoteEchoHeadersServer = new Uri("wss://" + SecureHost + "/" + EchoHeadersHandler);

public static object[][] GetEchoServers()
{
if (PlatformDetection.IsFirefox)
{
// https://github.com/dotnet/runtime/issues/101115
return new object[][] {
new object[] { RemoteEchoServer },
};
}
return new object[][] {
new object[] { RemoteEchoServer },
new object[] { SecureRemoteEchoServer },
};
}

public static object[][] GetEchoHeadersServers()
{
if (PlatformDetection.IsFirefox)
{
// https://github.com/dotnet/runtime/issues/101115
return new object[][] {
new object[] { RemoteEchoHeadersServer },
};
}
return new object[][] {
new object[] { RemoteEchoHeadersServer },
new object[] { SecureRemoteEchoHeadersServer },
};
}
public static Uri[] GetEchoServers() => PlatformDetection.IsFirefox
? [ RemoteEchoServer ] // https://github.com/dotnet/runtime/issues/101115
: [ RemoteEchoServer, SecureRemoteEchoServer ];

public static Uri[] GetEchoHeadersServers() => PlatformDetection.IsFirefox
? [ RemoteEchoHeadersServer ] // https://github.com/dotnet/runtime/issues/101115
: [ RemoteEchoHeadersServer, SecureRemoteEchoHeadersServer ];
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,19 +28,68 @@ public class Http2LoopbackConnection : GenericLoopbackConnection
private readonly TimeSpan _timeout;
private int _lastStreamId;
private bool _expectClientDisconnect;
private readonly SemaphoreSlim? _readLock;
private readonly SemaphoreSlim? _writeLock;

private readonly byte[] _prefix = new byte[24];
public string PrefixString => Encoding.UTF8.GetString(_prefix, 0, _prefix.Length);
public bool IsInvalid => _connectionSocket == null;
public Stream Stream => _connectionStream;
public Task<bool> SettingAckWaiter => _ignoredSettingsAckPromise?.Task;

private Http2LoopbackConnection(SocketWrapper socket, Stream stream, TimeSpan timeout, bool transparentPingResponse)
private Http2LoopbackConnection(SocketWrapper socket, Stream stream, TimeSpan timeout, Http2Options httpOptions)
{
_connectionSocket = socket;
_connectionStream = stream;
_timeout = timeout;
_transparentPingResponse = transparentPingResponse;
_transparentPingResponse = httpOptions.EnableTransparentPingResponse;

if (httpOptions.EnsureThreadSafeIO)
{
_readLock = new SemaphoreSlim(1, 1);
_writeLock = new SemaphoreSlim(1, 1);
_connectionStream = CreateConcurrentConnectionStream(stream, _readLock, _writeLock);
}

static Stream CreateConcurrentConnectionStream(Stream stream, SemaphoreSlim readLock, SemaphoreSlim writeLock)
{
return new DelegateStream(
canReadFunc: () => true,
canWriteFunc: () => true,
readAsyncFunc: async (buffer, offset, count, cancellationToken) =>
{
await readLock.WaitAsync(cancellationToken);
try
{
return await stream.ReadAsync(buffer, offset, count, cancellationToken);
}
finally
{
readLock.Release();
}
},
writeAsyncFunc: async (buffer, offset, count, cancellationToken) =>
{
await writeLock.WaitAsync(cancellationToken);
try
{
await stream.WriteAsync(buffer, offset, count, cancellationToken);
await stream.FlushAsync(cancellationToken);
}
finally
{
writeLock.Release();
}
},
disposeFunc: (disposing) =>
{
if (disposing)
{
stream.Dispose();
}
}
);
}
}

public override string ToString()
Expand Down Expand Up @@ -83,7 +132,7 @@ public static async Task<Http2LoopbackConnection> CreateAsync(SocketWrapper sock
stream = sslStream;
}

var con = new Http2LoopbackConnection(socket, stream, timeout, httpOptions.EnableTransparentPingResponse);
var con = new Http2LoopbackConnection(socket, stream, timeout, httpOptions);
await con.ReadPrefixAsync().ConfigureAwait(false);

return con;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,7 @@ public class Http2Options : GenericLoopbackOptions
public bool ClientCertificateRequired { get; set; }

public bool EnableTransparentPingResponse { get; set; } = true;
public bool EnsureThreadSafeIO { get; set; }

public Http2Options()
{
Expand Down Expand Up @@ -216,7 +217,12 @@ public override async Task<GenericLoopbackConnection> CreateConnectionAsync(Sock

private static Http2Options CreateOptions(GenericLoopbackOptions options)
{
Http2Options http2Options = new Http2Options();
if (options is Http2Options http2Options)
{
return http2Options;
}

http2Options = new Http2Options();
if (options != null)
{
http2Options.Address = options.Address;
Expand Down
5 changes: 5 additions & 0 deletions src/libraries/Common/tests/System/Net/Http/LoopbackServer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1145,6 +1145,11 @@ public override async Task<GenericLoopbackConnection> CreateConnectionAsync(Sock

private static LoopbackServer.Options CreateOptions(GenericLoopbackOptions options)
{
if (options is LoopbackServer.Options { } loopbackOptions)
{
return loopbackOptions;
}

LoopbackServer.Options newOptions = new LoopbackServer.Options();
if (options != null)
{
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
<Project>
<PropertyGroup>
<RepositoryRoot Condition="'$(RepositoryRoot)' == ''">$([MSBuild]::GetDirectoryNameOfFileAbove($(MSBuildThisFileDirectory)../, global.json))/</RepositoryRoot>
<CommonTestPath Condition="'$(CommonTestPath)' == ''">$([MSBuild]::NormalizeDirectory('$(RepositoryRoot)', 'src', 'libraries', 'Common', 'tests'))</CommonTestPath>
</PropertyGroup>

<Import Project="$([MSBuild]::NormalizePath($(RepositoryRoot), 'eng', 'testing', 'ForXHarness.Directory.Build.targets'))" />
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,186 +3,33 @@

using System;
using System.Net.WebSockets;
using System.Text;
using System.Threading;
using System.Net.Test.Common;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Http;

namespace NetCoreServer
{
public class EchoWebSocketHandler
{
private const int MaxBufferSize = 128 * 1024;

public static async Task InvokeAsync(HttpContext context)
{
QueryString queryString = context.Request.QueryString;
bool replyWithPartialMessages = queryString.HasValue && queryString.Value.Contains("replyWithPartialMessages");
bool replyWithEnhancedCloseMessage = queryString.HasValue && queryString.Value.Contains("replyWithEnhancedCloseMessage");

string subProtocol = context.Request.Query["subprotocol"];

if (context.Request.QueryString.HasValue && context.Request.QueryString.Value.Contains("delay10sec"))
{
await Task.Delay(10000);
}
else if (context.Request.QueryString.HasValue && context.Request.QueryString.Value.Contains("delay20sec"))
{
await Task.Delay(20000);
}

var queryString = context.Request.QueryString.ToUriComponent(); // Returns empty string if request URI has no query
WebSocketEchoOptions options = await WebSocketEchoHelper.ProcessOptions(queryString);
try
{
if (!context.WebSockets.IsWebSocketRequest)
WebSocket socket = await WebSocketAcceptHelper.AcceptAsync(context, options.SubProtocol);
if (socket is null)
{
context.Response.StatusCode = 200;
context.Response.ContentType = "text/plain";
await context.Response.WriteAsync("Not a websocket request");

return;
}

WebSocket socket;
if (!string.IsNullOrEmpty(subProtocol))
{
socket = await context.WebSockets.AcceptWebSocketAsync(subProtocol);
}
else
{
socket = await context.WebSockets.AcceptWebSocketAsync();
}

await ProcessWebSocketRequest(socket, replyWithPartialMessages, replyWithEnhancedCloseMessage);
await WebSocketEchoHelper.RunEchoAll(
socket, options.ReplyWithPartialMessages, options.ReplyWithEnhancedCloseMessage);
}
catch (Exception)
{
// We might want to log these exceptions. But for now we ignore them.
}
}

private static async Task ProcessWebSocketRequest(
WebSocket socket,
bool replyWithPartialMessages,
bool replyWithEnhancedCloseMessage)
{
var receiveBuffer = new byte[MaxBufferSize];
var throwAwayBuffer = new byte[MaxBufferSize];

// Stay in loop while websocket is open
while (socket.State == WebSocketState.Open || socket.State == WebSocketState.CloseSent)
{
var receiveResult = await socket.ReceiveAsync(new ArraySegment<byte>(receiveBuffer), CancellationToken.None);
if (receiveResult.MessageType == WebSocketMessageType.Close)
{
if (receiveResult.CloseStatus == WebSocketCloseStatus.Empty)
{
await socket.CloseAsync(WebSocketCloseStatus.Empty, null, CancellationToken.None);
}
else
{
WebSocketCloseStatus closeStatus = receiveResult.CloseStatus.GetValueOrDefault();
await socket.CloseAsync(
closeStatus,
replyWithEnhancedCloseMessage ?
("Server received: " + (int)closeStatus + " " + receiveResult.CloseStatusDescription) :
receiveResult.CloseStatusDescription,
CancellationToken.None);
}

continue;
}

// Keep reading until we get an entire message.
int offset = receiveResult.Count;
while (receiveResult.EndOfMessage == false)
{
if (offset < MaxBufferSize)
{
receiveResult = await socket.ReceiveAsync(
new ArraySegment<byte>(receiveBuffer, offset, MaxBufferSize - offset),
CancellationToken.None);
}
else
{
receiveResult = await socket.ReceiveAsync(
new ArraySegment<byte>(throwAwayBuffer),
CancellationToken.None);
}

offset += receiveResult.Count;
}

// Close socket if the message was too big.
if (offset > MaxBufferSize)
{
await socket.CloseAsync(
WebSocketCloseStatus.MessageTooBig,
String.Format("{0}: {1} > {2}", WebSocketCloseStatus.MessageTooBig.ToString(), offset, MaxBufferSize),
CancellationToken.None);

continue;
}

bool sendMessage = false;
string receivedMessage = null;
if (receiveResult.MessageType == WebSocketMessageType.Text)
{
receivedMessage = Encoding.UTF8.GetString(receiveBuffer, 0, offset);
if (receivedMessage == ".close")
{
await socket.CloseAsync(WebSocketCloseStatus.NormalClosure, receivedMessage, CancellationToken.None);
}
else if (receivedMessage == ".shutdown")
{
await socket.CloseOutputAsync(WebSocketCloseStatus.NormalClosure, receivedMessage, CancellationToken.None);
}
else if (receivedMessage == ".abort")
{
socket.Abort();
}
else if (receivedMessage == ".delay5sec")
{
await Task.Delay(5000);
}
else if (receivedMessage == ".receiveMessageAfterClose")
{
byte[] buffer = new byte[1024];
string message = $"{receivedMessage} {DateTime.Now.ToString("HH:mm:ss")}";
buffer = System.Text.Encoding.UTF8.GetBytes(message);
await socket.SendAsync(
new ArraySegment<byte>(buffer, 0, message.Length),
WebSocketMessageType.Text,
true,
CancellationToken.None);
await socket.CloseAsync(WebSocketCloseStatus.NormalClosure, receivedMessage, CancellationToken.None);
}
else if (socket.State == WebSocketState.Open)
{
sendMessage = true;
}
}
else
{
sendMessage = true;
}

if (sendMessage)
{
await socket.SendAsync(
new ArraySegment<byte>(receiveBuffer, 0, offset),
receiveResult.MessageType,
!replyWithPartialMessages,
CancellationToken.None);
}
if (receivedMessage == ".closeafter")
{
await socket.CloseAsync(WebSocketCloseStatus.NormalClosure, receivedMessage, CancellationToken.None);
}
else if (receivedMessage == ".shutdownafter")
{
await socket.CloseOutputAsync(WebSocketCloseStatus.NormalClosure, receivedMessage, CancellationToken.None);
}
}
}
}
}
Loading
Loading