diff --git a/src/Hosting/Hosting.slnf b/src/Hosting/Hosting.slnf index 80367a5d5c24..55ec034e1579 100644 --- a/src/Hosting/Hosting.slnf +++ b/src/Hosting/Hosting.slnf @@ -17,14 +17,15 @@ "src\\Hosting\\test\\FunctionalTests\\Microsoft.AspNetCore.Hosting.FunctionalTests.csproj", "src\\Hosting\\test\\testassets\\IStartupInjectionAssemblyName\\IStartupInjectionAssemblyName.csproj", "src\\Hosting\\test\\testassets\\TestStartupAssembly1\\TestStartupAssembly1.csproj", + "src\\Http\\Features\\src\\Microsoft.Extensions.Features.csproj", "src\\Http\\Headers\\src\\Microsoft.Net.Http.Headers.csproj", "src\\Http\\Http.Abstractions\\src\\Microsoft.AspNetCore.Http.Abstractions.csproj", "src\\Http\\Http.Extensions\\src\\Microsoft.AspNetCore.Http.Extensions.csproj", - "src\\Http\\Features\\src\\Microsoft.Extensions.Features.csproj", "src\\Http\\Http.Features\\src\\Microsoft.AspNetCore.Http.Features.csproj", "src\\Http\\Http\\src\\Microsoft.AspNetCore.Http.csproj", "src\\Http\\Owin\\src\\Microsoft.AspNetCore.Owin.csproj", "src\\Http\\WebUtilities\\src\\Microsoft.AspNetCore.WebUtilities.csproj", + "src\\Middleware\\WebSockets\\src\\Microsoft.AspNetCore.WebSockets.csproj", "src\\ObjectPool\\src\\Microsoft.Extensions.ObjectPool.csproj", "src\\Servers\\Connections.Abstractions\\src\\Microsoft.AspNetCore.Connections.Abstractions.csproj", "src\\Servers\\Kestrel\\Core\\src\\Microsoft.AspNetCore.Server.Kestrel.Core.csproj", diff --git a/src/Hosting/TestHost/src/HttpContextBuilder.cs b/src/Hosting/TestHost/src/HttpContextBuilder.cs index 9534fee51491..ed251d595de2 100644 --- a/src/Hosting/TestHost/src/HttpContextBuilder.cs +++ b/src/Hosting/TestHost/src/HttpContextBuilder.cs @@ -57,6 +57,7 @@ internal HttpContextBuilder(ApplicationWrapper application, bool allowSynchronou _httpContext.Features.Set(_responseFeature); _httpContext.Features.Set(_requestLifetimeFeature); _httpContext.Features.Set(_responseTrailersFeature); + _httpContext.Features.Set(new UpgradeFeature()); } public bool AllowSynchronousIO { get; set; } diff --git a/src/Hosting/TestHost/src/UpgradeFeature.cs b/src/Hosting/TestHost/src/UpgradeFeature.cs new file mode 100644 index 000000000000..15cdb031ad36 --- /dev/null +++ b/src/Hosting/TestHost/src/UpgradeFeature.cs @@ -0,0 +1,21 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.IO; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Http.Features; + +namespace Microsoft.AspNetCore.TestHost +{ + internal class UpgradeFeature : IHttpUpgradeFeature + { + public bool IsUpgradableRequest => false; + + // TestHost provides an IHttpWebSocketFeature so it wont call UpgradeAsync() + public Task UpgradeAsync() + { + throw new NotSupportedException(); + } + } +} diff --git a/src/Hosting/TestHost/test/Microsoft.AspNetCore.TestHost.Tests.csproj b/src/Hosting/TestHost/test/Microsoft.AspNetCore.TestHost.Tests.csproj index 6f77fe58d6e6..820cf3800f65 100644 --- a/src/Hosting/TestHost/test/Microsoft.AspNetCore.TestHost.Tests.csproj +++ b/src/Hosting/TestHost/test/Microsoft.AspNetCore.TestHost.Tests.csproj @@ -12,6 +12,7 @@ + diff --git a/src/Hosting/TestHost/test/TestClientTests.cs b/src/Hosting/TestHost/test/TestClientTests.cs index ba175599e424..76eada9f0395 100644 --- a/src/Hosting/TestHost/test/TestClientTests.cs +++ b/src/Hosting/TestHost/test/TestClientTests.cs @@ -13,6 +13,7 @@ using Microsoft.AspNetCore.Builder; using Microsoft.AspNetCore.Hosting; using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Http.Features; using Microsoft.AspNetCore.Internal; using Microsoft.AspNetCore.Testing; using Microsoft.Extensions.DependencyInjection; @@ -960,5 +961,35 @@ public async Task SendAsync_ExplicitlySet_Protocol20() Assert.Equal(expected, actual); Assert.Equal(new Version(2, 0), message.Version); } + + [Fact] + public async Task VerifyWebSocketAndUpgradeFeaturesForNonWebSocket() + { + using (var testServer = new TestServer(new WebHostBuilder() + .Configure(app => + { + app.UseWebSockets(); + app.Run(async c => + { + var upgradeFeature = c.Features.Get(); + // Feature needs to exist for SignalR to verify that the server supports WebSockets + Assert.NotNull(upgradeFeature); + Assert.False(upgradeFeature.IsUpgradableRequest); + await Assert.ThrowsAsync(() => upgradeFeature.UpgradeAsync()); + + var webSocketFeature = c.Features.Get(); + Assert.NotNull(webSocketFeature); + Assert.False(webSocketFeature.IsWebSocketRequest); + + await c.Response.WriteAsync("test"); + }); + }))) + { + var client = testServer.CreateClient(); + + var actual = await client.GetStringAsync("http://localhost:12345/"); + Assert.Equal("test", actual); + } + } } } diff --git a/src/Hosting/TestHost/test/WebSocketClientTests.cs b/src/Hosting/TestHost/test/WebSocketClientTests.cs index 40c3b234654e..acb3960c5154 100644 --- a/src/Hosting/TestHost/test/WebSocketClientTests.cs +++ b/src/Hosting/TestHost/test/WebSocketClientTests.cs @@ -5,6 +5,7 @@ using System.Threading.Tasks; using Microsoft.AspNetCore.Builder; using Microsoft.AspNetCore.Hosting; +using Microsoft.AspNetCore.Http.Features; using Xunit; namespace Microsoft.AspNetCore.TestHost.Tests @@ -54,5 +55,77 @@ await client.ConnectAsync( Assert.Equal(expectedHost, capturedHost); Assert.Equal("/connect", capturedPath); } + + [Fact] + public async Task CanAcceptWebSocket() + { + using (var testServer = new TestServer(new WebHostBuilder() + .Configure(app => + { + app.UseWebSockets(); + app.Run(async ctx => + { + if (ctx.Request.Path.StartsWithSegments("/connect")) + { + if (ctx.WebSockets.IsWebSocketRequest) + { + using var websocket = await ctx.WebSockets.AcceptWebSocketAsync(); + var buffer = new byte[1000]; + var res = await websocket.ReceiveAsync(buffer, default); + await websocket.SendAsync(buffer.AsMemory(0, res.Count), System.Net.WebSockets.WebSocketMessageType.Binary, true, default); + await websocket.CloseAsync(System.Net.WebSockets.WebSocketCloseStatus.NormalClosure, null, default); + } + } + }); + }))) + { + var client = testServer.CreateWebSocketClient(); + + using var socket = await client.ConnectAsync( + uri: new Uri("http://localhost/connect"), + cancellationToken: default); + + await socket.SendAsync(new byte[10], System.Net.WebSockets.WebSocketMessageType.Binary, true, default); + var res = await socket.ReceiveAsync(new byte[100], default); + Assert.Equal(10, res.Count); + Assert.True(res.EndOfMessage); + + await socket.CloseAsync(System.Net.WebSockets.WebSocketCloseStatus.NormalClosure, null, default); + } + } + + [Fact] + public async Task VerifyWebSocketAndUpgradeFeatures() + { + using (var testServer = new TestServer(new WebHostBuilder() + .Configure(app => + { + app.Run(async c => + { + var upgradeFeature = c.Features.Get(); + Assert.NotNull(upgradeFeature); + Assert.False(upgradeFeature.IsUpgradableRequest); + await Assert.ThrowsAsync(() => upgradeFeature.UpgradeAsync()); + + var webSocketFeature = c.Features.Get(); + Assert.NotNull(webSocketFeature); + Assert.True(webSocketFeature.IsWebSocketRequest); + }); + }))) + { + var client = testServer.CreateWebSocketClient(); + + try + { + using var socket = await client.ConnectAsync( + uri: new Uri("http://localhost/connect"), + cancellationToken: default); + } + catch + { + // An exception will be thrown because our endpoint does not accept the websocket + } + } + } } } diff --git a/src/SignalR/clients/csharp/Client/test/UnitTests/Microsoft.AspNetCore.SignalR.Client.Tests.csproj b/src/SignalR/clients/csharp/Client/test/UnitTests/Microsoft.AspNetCore.SignalR.Client.Tests.csproj index 7c584a84bc8b..44b2b2fb6207 100644 --- a/src/SignalR/clients/csharp/Client/test/UnitTests/Microsoft.AspNetCore.SignalR.Client.Tests.csproj +++ b/src/SignalR/clients/csharp/Client/test/UnitTests/Microsoft.AspNetCore.SignalR.Client.Tests.csproj @@ -18,6 +18,8 @@ + + diff --git a/src/SignalR/clients/csharp/Client/test/UnitTests/TestServerTests.cs b/src/SignalR/clients/csharp/Client/test/UnitTests/TestServerTests.cs new file mode 100644 index 000000000000..561398b95f0f --- /dev/null +++ b/src/SignalR/clients/csharp/Client/test/UnitTests/TestServerTests.cs @@ -0,0 +1,121 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.Threading.Tasks; +using Microsoft.AspNetCore.Builder; +using Microsoft.AspNetCore.Hosting; +using Microsoft.AspNetCore.SignalR.Tests; +using Microsoft.AspNetCore.TestHost; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; +using Xunit; + +namespace Microsoft.AspNetCore.SignalR.Client.Tests +{ + public class TestServerTests : VerifiableLoggedTest + { + [Fact] + public async Task WebSocketsWorks() + { + using (StartVerifiableLog()) + { + var builder = new WebHostBuilder().ConfigureServices(s => + { + s.AddLogging(); + s.AddSingleton(LoggerFactory); + s.AddSignalR(); + }).Configure(app => + { + app.UseRouting(); + app.UseEndpoints(endpoints => + { + endpoints.MapHub("/echo"); + }); + }); + var server = new TestServer(builder); + + var webSocketFactoryCalled = false; + var connectionBuilder = new HubConnectionBuilder() + .WithUrl(server.BaseAddress + "echo", options => + { + options.Transports = Http.Connections.HttpTransportType.WebSockets; + options.HttpMessageHandlerFactory = _ => + { + return server.CreateHandler(); + }; + options.WebSocketFactory = async (context, token) => + { + webSocketFactoryCalled = true; + var wsClient = server.CreateWebSocketClient(); + return await wsClient.ConnectAsync(context.Uri, default); + }; + }); + connectionBuilder.Services.AddLogging(); + connectionBuilder.Services.AddSingleton(LoggerFactory); + var connection = connectionBuilder.Build(); + + var originalMessage = "message"; + connection.On("Echo", (receivedMessage) => + { + Assert.Equal(originalMessage, receivedMessage); + }); + + await connection.StartAsync(); + await connection.InvokeAsync("Echo", originalMessage); + Assert.True(webSocketFactoryCalled); + } + } + + [Fact] + public async Task LongPollingWorks() + { + using (StartVerifiableLog()) + { + var builder = new WebHostBuilder().ConfigureServices(s => + { + s.AddLogging(); + s.AddSingleton(LoggerFactory); + s.AddSignalR(); + }).Configure(app => + { + app.UseRouting(); + app.UseEndpoints(endpoints => + { + endpoints.MapHub("/echo"); + }); + }); + var server = new TestServer(builder); + + var connectionBuilder = new HubConnectionBuilder() + .WithUrl(server.BaseAddress + "echo", options => + { + options.Transports = Http.Connections.HttpTransportType.LongPolling; + options.HttpMessageHandlerFactory = _ => + { + return server.CreateHandler(); + }; + }); + connectionBuilder.Services.AddLogging(); + connectionBuilder.Services.AddSingleton(LoggerFactory); + var connection = connectionBuilder.Build(); + + var originalMessage = "message"; + connection.On("Echo", (receivedMessage) => + { + Assert.Equal(originalMessage, receivedMessage); + }); + + await connection.StartAsync(); + await connection.InvokeAsync("Echo", originalMessage); + } + } + } + + class EchoHub : Hub + { + public Task Echo(string message) + { + return Clients.All.SendAsync("Echo", message); + } + } +}