diff --git a/samples/ReverseProxy.Metrics.Sample/Startup.cs b/samples/ReverseProxy.Metrics.Sample/Startup.cs index cdad7c9d4..0cf587190 100644 --- a/samples/ReverseProxy.Metrics.Sample/Startup.cs +++ b/samples/ReverseProxy.Metrics.Sample/Startup.cs @@ -44,6 +44,8 @@ public void ConfigureServices(IServiceCollection services) // Registration of a consumer to events for HttpClient telemetry // Note: this depends on changes implemented in .NET 5 services.AddTelemetryConsumer(); + + services.AddTelemetryConsumer(); } /// @@ -55,6 +57,9 @@ public void Configure(IApplicationBuilder app) // Placed at the beginning so it is the first and last thing run for each request app.UsePerRequestMetricCollection(); + // Middleware used to intercept the WebSocket connection and collect telemetry exposed to WebSocketsTelemetryConsumer + app.UseWebSocketsTelemetry(); + app.UseRouting(); app.UseEndpoints(endpoints => { diff --git a/samples/ReverseProxy.Metrics.Sample/WebSocketsTelemetryConsumer.cs b/samples/ReverseProxy.Metrics.Sample/WebSocketsTelemetryConsumer.cs new file mode 100644 index 000000000..fd9b066d2 --- /dev/null +++ b/samples/ReverseProxy.Metrics.Sample/WebSocketsTelemetryConsumer.cs @@ -0,0 +1,21 @@ +using System; +using Microsoft.Extensions.Logging; +using Yarp.Telemetry.Consumption; + +namespace Yarp.Sample +{ + public sealed class WebSocketsTelemetryConsumer : IWebSocketsTelemetryConsumer + { + private readonly ILogger _logger; + + public WebSocketsTelemetryConsumer(ILogger logger) + { + _logger = logger ?? throw new ArgumentNullException(nameof(logger)); + } + + public void OnWebSocketClosed(DateTime timestamp, DateTime establishedTime, WebSocketCloseReason closeReason, long messagesRead, long messagesWritten) + { + _logger.LogInformation($"WebSocket connection closed ({closeReason}) after reading {messagesRead} and writing {messagesWritten} messages over {(timestamp - establishedTime).TotalSeconds:N2} seconds."); + } + } +} diff --git a/src/ReverseProxy/Forwarder/HttpForwarder.cs b/src/ReverseProxy/Forwarder/HttpForwarder.cs index 2d6c14ca0..919f37e47 100644 --- a/src/ReverseProxy/Forwarder/HttpForwarder.cs +++ b/src/ReverseProxy/Forwarder/HttpForwarder.cs @@ -576,7 +576,7 @@ private async ValueTask HandleUpgradedResponse(HttpContext conte var (secondResult, secondException) = await secondTask; if (secondResult != StreamCopyResult.Success) { - error = ReportResult(context, requestFinishedFirst, secondResult, secondException!); + error = ReportResult(context, !requestFinishedFirst, secondResult, secondException!); } else { diff --git a/test/ReverseProxy.Tests/Common/DelegatingStream.cs b/src/ReverseProxy/Utilities/DelegatingStream.cs similarity index 91% rename from test/ReverseProxy.Tests/Common/DelegatingStream.cs rename to src/ReverseProxy/Utilities/DelegatingStream.cs index 1466b620f..597f2a45a 100644 --- a/test/ReverseProxy.Tests/Common/DelegatingStream.cs +++ b/src/ReverseProxy/Utilities/DelegatingStream.cs @@ -1,5 +1,5 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. using System; using System.Diagnostics; @@ -7,8 +7,9 @@ using System.Threading; using System.Threading.Tasks; -namespace Yarp.Tests.Common +namespace Yarp.ReverseProxy.Utilities { + // Taken from https://github.com/dotnet/runtime/blob/00f37bc13b4edbba1afca9e98d74432a94f5192f/src/libraries/Common/src/System/IO/DelegatingStream.cs // Forwards all calls to an inner stream except where overridden in a derived class. internal abstract class DelegatingStream : Stream { @@ -113,9 +114,9 @@ public override ValueTask ReadAsync(Memory buffer, CancellationToken return _innerStream.ReadAsync(buffer, cancellationToken); } - public override IAsyncResult BeginRead(byte[] buffer, int offset, int count, AsyncCallback callback, object state) + public override IAsyncResult BeginRead(byte[] buffer, int offset, int count, AsyncCallback? callback, object? state) { - return _innerStream.BeginRead(buffer, offset, count, callback, state); + return _innerStream.BeginRead(buffer, offset, count, callback!, state); } public override int EndRead(IAsyncResult asyncResult) @@ -167,9 +168,9 @@ public override ValueTask WriteAsync(ReadOnlyMemory buffer, CancellationTo return _innerStream.WriteAsync(buffer, cancellationToken); } - public override IAsyncResult BeginWrite(byte[] buffer, int offset, int count, AsyncCallback callback, object state) + public override IAsyncResult BeginWrite(byte[] buffer, int offset, int count, AsyncCallback? callback, object? state) { - return _innerStream.BeginWrite(buffer, offset, count, callback, state); + return _innerStream.BeginWrite(buffer, offset, count, callback!, state); } public override void EndWrite(IAsyncResult asyncResult) diff --git a/src/ReverseProxy/WebSocketsTelemetry/WebSocketCloseReason.cs b/src/ReverseProxy/WebSocketsTelemetry/WebSocketCloseReason.cs new file mode 100644 index 000000000..949048cdf --- /dev/null +++ b/src/ReverseProxy/WebSocketsTelemetry/WebSocketCloseReason.cs @@ -0,0 +1,14 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +namespace Yarp.ReverseProxy.WebSocketsTelemetry +{ + internal enum WebSocketCloseReason : int + { + Unknown, + ClientGracefulClose, + ServerGracefulClose, + ClientDisconnect, + ServerDisconnect, + } +} diff --git a/src/ReverseProxy/WebSocketsTelemetry/WebSocketsParser.cs b/src/ReverseProxy/WebSocketsTelemetry/WebSocketsParser.cs new file mode 100644 index 000000000..48931fb9f --- /dev/null +++ b/src/ReverseProxy/WebSocketsTelemetry/WebSocketsParser.cs @@ -0,0 +1,142 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System; +using System.Diagnostics; +using Yarp.ReverseProxy.Utilities; + +namespace Yarp.ReverseProxy.WebSocketsTelemetry +{ + internal unsafe struct WebSocketsParser + { + private const int MaskLength = 4; + private const int MinHeaderSize = 2; + private const int MaxHeaderSize = MinHeaderSize + MaskLength + sizeof(ulong); + + private fixed byte _leftoverBuffer[MaxHeaderSize - 1]; + private readonly byte _minHeaderSize; + private byte _leftover; + private ulong _bytesToSkip; + private long _closeTime; + private readonly IClock _clock; + + public long MessageCount { get; private set; } + + public DateTime? CloseTime => _closeTime == 0 ? null : new DateTime(_closeTime, DateTimeKind.Utc); + + public WebSocketsParser(IClock clock, bool isServer) + { + _minHeaderSize = (byte)(MinHeaderSize + (isServer ? MaskLength : 0)); + _leftover = 0; + _bytesToSkip = 0; + _closeTime = 0; + _clock = clock; + MessageCount = 0; + } + + // The WebSocket Protocol: https://datatracker.ietf.org/doc/html/rfc6455#section-5.2 + // 0 1 2 3 + // 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + // +-+-+-+-+-------+-+-------------+-------------------------------+ + // |F|R|R|R| opcode|M| Payload len | Extended payload length | + // |I|S|S|S| (4) |A| (7) | (16/64) | + // |N|V|V|V| |S| | (if payload len==126/127) | + // | |1|2|3| |K| | | + // +-+-+-+-+-------+-+-------------+ - - - - - - - - - - - - - - - + + // | Extended payload length continued, if payload len == 127 | + // + - - - - - - - - - - - - - - - +-------------------------------+ + // | |Masking-key, if MASK set to 1 | + // +-------------------------------+-------------------------------+ + // | Masking-key (continued) | Payload Data | + // +-------------------------------- - - - - - - - - - - - - - - - + + // : Payload Data continued ... : + // +---------------------------------------------------------------+ + // + // The header can be 2-10 bytes long, followed by a 4 byte mask if the message was sent by the client. + // We have to read the first 2 bytes to know how long the frame header will be. + // Since the buffer may not contain the full frame, we make use of a leftoverBuffer + // where we store leftover bytes that don't represent a complete frame header. + // On the next call to Consume, we interpret the leftover bytes as the beginning of the frame. + // As we are not interested in the actual payload data, we skip over (payload length + mask length) bytes after each header. + public void Consume(ReadOnlySpan buffer) + { + int leftover = _leftover; + var bytesToSkip = _bytesToSkip; + + while (true) + { + var toSkip = Math.Min(bytesToSkip, (ulong)buffer.Length); + buffer = buffer.Slice((int)toSkip); + bytesToSkip -= toSkip; + + var available = leftover + buffer.Length; + int headerSize = _minHeaderSize; + + if (available < headerSize) + { + break; + } + + var length = (leftover > 1 ? _leftoverBuffer[1] : buffer[1 - leftover]) & 0x7FUL; + + if (length > 125) + { + // The actual length will be encoded in 2 or 8 bytes, based on whether the length was 126 or 127 + var lengthBytes = 2 << (((int)length & 1) << 1); + headerSize += lengthBytes; + Debug.Assert(leftover < headerSize); + + if (available < headerSize) + { + break; + } + + lengthBytes += MinHeaderSize; + + length = 0; + for (var i = MinHeaderSize; i < lengthBytes; i++) + { + length <<= 8; + length |= i < leftover ? _leftoverBuffer[i] : buffer[i - leftover]; + } + } + + Debug.Assert(leftover < headerSize); + bytesToSkip = length; + + const int NonReservedBitsMask = 0b_1000_1111; + var header = (leftover > 0 ? _leftoverBuffer[0] : buffer[0]) & NonReservedBitsMask; + + // Don't count control frames under MessageCount + if ((uint)(header - 0x80) <= 0x02) + { + // Has FIN (0x80) and is a Continuation (0x00) / Text (0x01) / Binary (0x02) opcode + MessageCount++; + } + else if ((header & 0xF) == 0x8) // CLOSE + { + if (_closeTime == 0) + { + _closeTime = _clock.GetUtcNow().Ticks; + } + } + + // Advance the buffer by the number of bytes read for the header, + // accounting for any bytes we may have read from the leftoverBuffer + buffer = buffer.Slice(headerSize - leftover); + leftover = 0; + } + + Debug.Assert(bytesToSkip == 0 || buffer.Length == 0); + _bytesToSkip = bytesToSkip; + + Debug.Assert(leftover + buffer.Length < MaxHeaderSize); + for (var i = 0; i < buffer.Length; i++, leftover++) + { + _leftoverBuffer[leftover] = buffer[i]; + } + + _leftover = (byte)leftover; + } + } +} diff --git a/src/ReverseProxy/WebSocketsTelemetry/WebSocketsTelemetry.cs b/src/ReverseProxy/WebSocketsTelemetry/WebSocketsTelemetry.cs new file mode 100644 index 000000000..6fa10793e --- /dev/null +++ b/src/ReverseProxy/WebSocketsTelemetry/WebSocketsTelemetry.cs @@ -0,0 +1,22 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System.Diagnostics.Tracing; + +namespace Yarp.ReverseProxy.WebSocketsTelemetry +{ + [EventSource(Name = "Yarp.ReverseProxy.WebSockets")] + internal sealed class WebSocketsTelemetry : EventSource + { + public static readonly WebSocketsTelemetry Log = new(); + + [Event(1, Level = EventLevel.Informational)] + public void WebSocketClosed(long establishedTime, WebSocketCloseReason closeReason, long messagesRead, long messagesWritten) + { + if (IsEnabled(EventLevel.Informational, EventKeywords.All)) + { + WriteEvent(eventId: 1, establishedTime, closeReason, messagesRead, messagesWritten); + } + } + } +} diff --git a/src/ReverseProxy/WebSocketsTelemetry/WebSocketsTelemetryExtensions.cs b/src/ReverseProxy/WebSocketsTelemetry/WebSocketsTelemetryExtensions.cs new file mode 100644 index 000000000..0ab8e5379 --- /dev/null +++ b/src/ReverseProxy/WebSocketsTelemetry/WebSocketsTelemetryExtensions.cs @@ -0,0 +1,30 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System.Linq; +using Microsoft.Extensions.DependencyInjection; +using Yarp.ReverseProxy.Utilities; +using Yarp.ReverseProxy.WebSocketsTelemetry; + +namespace Microsoft.AspNetCore.Builder +{ + /// + /// extension methods to add the . + /// + public static class WebSocketsTelemetryExtensions + { + /// + /// Adds a to the request pipeline. + /// Must be added before . + /// + public static IApplicationBuilder UseWebSocketsTelemetry(this IApplicationBuilder app) + { + return app.Use(next => + { + // Avoid exposing another extension method (AddWebSocketsTelemetry) just because of IClock + var clock = app.ApplicationServices.GetServices().FirstOrDefault() ?? new Clock(); + return new WebSocketsTelemetryMiddleware(next, clock).InvokeAsync; + }); + } + } +} diff --git a/src/ReverseProxy/WebSocketsTelemetry/WebSocketsTelemetryMiddleware.cs b/src/ReverseProxy/WebSocketsTelemetry/WebSocketsTelemetryMiddleware.cs new file mode 100644 index 000000000..77c0bfde4 --- /dev/null +++ b/src/ReverseProxy/WebSocketsTelemetry/WebSocketsTelemetryMiddleware.cs @@ -0,0 +1,98 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System; +using System.Diagnostics; +using System.IO; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Http.Features; +using Microsoft.Net.Http.Headers; +using Yarp.ReverseProxy.Utilities; + +namespace Yarp.ReverseProxy.WebSocketsTelemetry +{ + internal sealed class WebSocketsTelemetryMiddleware + { + private readonly RequestDelegate _next; + private readonly IClock _clock; + + public WebSocketsTelemetryMiddleware(RequestDelegate next, IClock clock) + { + _next = next ?? throw new ArgumentNullException(nameof(next)); + _clock = clock ?? throw new ArgumentNullException(nameof(clock)); + } + + public Task InvokeAsync(HttpContext context) + { + if (WebSocketsTelemetry.Log.IsEnabled()) + { + if (context.Features.Get() is { IsUpgradableRequest: true } upgradeFeature) + { + var upgradeWrapper = new HttpUpgradeFeatureWrapper(_clock, context, upgradeFeature); + return InvokeAsyncCore(upgradeWrapper, _next); + } + } + + return _next(context); + } + + private static async Task InvokeAsyncCore(HttpUpgradeFeatureWrapper upgradeWrapper, RequestDelegate next) + { + upgradeWrapper.HttpContext.Features.Set(upgradeWrapper); + + try + { + await next(upgradeWrapper.HttpContext); + } + finally + { + if (upgradeWrapper.TelemetryStream is { } telemetryStream) + { + WebSocketsTelemetry.Log.WebSocketClosed( + telemetryStream.EstablishedTime.Ticks, + telemetryStream.GetCloseReason(upgradeWrapper.HttpContext), + telemetryStream.MessagesRead, + telemetryStream.MessagesWritten); + } + + upgradeWrapper.HttpContext.Features.Set(upgradeWrapper.InnerUpgradeFeature); + } + } + + private sealed class HttpUpgradeFeatureWrapper : IHttpUpgradeFeature + { + private readonly IClock _clock; + + public HttpContext HttpContext { get; private set; } + + public IHttpUpgradeFeature InnerUpgradeFeature { get; private set; } + + public WebSocketsTelemetryStream? TelemetryStream { get; private set; } + + public bool IsUpgradableRequest => InnerUpgradeFeature.IsUpgradableRequest; + + public HttpUpgradeFeatureWrapper(IClock clock, HttpContext httpContext, IHttpUpgradeFeature upgradeFeature) + { + _clock = clock ?? throw new ArgumentNullException(nameof(clock)); + HttpContext = httpContext ?? throw new ArgumentNullException(nameof(httpContext)); + InnerUpgradeFeature = upgradeFeature ?? throw new ArgumentNullException(nameof(upgradeFeature)); + } + + public async Task UpgradeAsync() + { + Debug.Assert(TelemetryStream is null); + var opaqueTransport = await InnerUpgradeFeature.UpgradeAsync(); + + if (HttpContext.Response.Headers.TryGetValue(HeaderNames.Upgrade, out var upgradeValues) && + upgradeValues.Count == 1 && + string.Equals("WebSocket", upgradeValues.ToString(), StringComparison.OrdinalIgnoreCase)) + { + TelemetryStream = new WebSocketsTelemetryStream(_clock, opaqueTransport); + } + + return TelemetryStream ?? opaqueTransport; + } + } + } +} diff --git a/src/ReverseProxy/WebSocketsTelemetry/WebSocketsTelemetryStream.cs b/src/ReverseProxy/WebSocketsTelemetry/WebSocketsTelemetryStream.cs new file mode 100644 index 000000000..f9f14d31b --- /dev/null +++ b/src/ReverseProxy/WebSocketsTelemetry/WebSocketsTelemetryStream.cs @@ -0,0 +1,99 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System; +using System.IO; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Http; +using Yarp.ReverseProxy.Forwarder; +using Yarp.ReverseProxy.Utilities; + +namespace Yarp.ReverseProxy.WebSocketsTelemetry +{ + internal sealed class WebSocketsTelemetryStream : DelegatingStream + { + private WebSocketsParser _readParser, _writeParser; + + public DateTime EstablishedTime { get; } + public long MessagesRead => _readParser.MessageCount; + public long MessagesWritten => _writeParser.MessageCount; + + public WebSocketsTelemetryStream(IClock clock, Stream innerStream) + : base(innerStream) + { + EstablishedTime = clock.GetUtcNow().UtcDateTime; + _readParser = new WebSocketsParser(clock, isServer: true); + _writeParser = new WebSocketsParser(clock, isServer: false); + } + + public WebSocketCloseReason GetCloseReason(HttpContext context) + { + var clientCloseTime = _readParser.CloseTime; + var serverCloseTime = _writeParser.CloseTime; + + // Mutual, graceful WebSocket close. We report whichever one we saw first. + if (clientCloseTime.HasValue && serverCloseTime.HasValue) + { + return clientCloseTime.Value < serverCloseTime.Value ? WebSocketCloseReason.ClientGracefulClose : WebSocketCloseReason.ServerGracefulClose; + } + + // One side sent a WebSocket close, but we never saw a response from the other side + // It is possible an error occurred, but we saw a graceful close first, so that is the intiator + if (clientCloseTime.HasValue) + { + return WebSocketCloseReason.ClientGracefulClose; + } + if (serverCloseTime.HasValue) + { + return WebSocketCloseReason.ServerGracefulClose; + } + + return context.Features.Get()?.Error switch + { + // Either side disconnected without sending a WebSocket close + ForwarderError.UpgradeRequestClient => WebSocketCloseReason.ClientDisconnect, + ForwarderError.UpgradeRequestCanceled => WebSocketCloseReason.ClientDisconnect, + ForwarderError.UpgradeResponseClient => WebSocketCloseReason.ClientDisconnect, + ForwarderError.UpgradeResponseCanceled => WebSocketCloseReason.ClientDisconnect, + ForwarderError.UpgradeRequestDestination => WebSocketCloseReason.ServerDisconnect, + ForwarderError.UpgradeResponseDestination => WebSocketCloseReason.ServerDisconnect, + + // Both sides gracefully closed the underlying connection without sending a WebSocket close + // Neither side is doing what we recognize as WebSockets ¯\_(ツ)_/¯ + null => WebSocketCloseReason.Unknown, + + // We are not expecting any other error from HttpForwarder after a successful connection upgrade + // Technically, a user could overwrite the IForwarderErrorFeature, in which case we don't know what's going on + _ => WebSocketCloseReason.Unknown + }; + } + + public override ValueTask ReadAsync(Memory buffer, CancellationToken cancellationToken = default) + { + var readTask = base.ReadAsync(buffer, cancellationToken); + + if (readTask.IsCompletedSuccessfully) + { + var read = readTask.GetAwaiter().GetResult(); + _readParser.Consume(buffer.Span.Slice(0, read)); + return new ValueTask(read); + } + + return Core(buffer, readTask); + + async ValueTask Core(Memory buffer, ValueTask readTask) + { + var read = await readTask; + _readParser.Consume(buffer.Span.Slice(0, read)); + return read; + } + } + + public override ValueTask WriteAsync(ReadOnlyMemory buffer, CancellationToken cancellationToken = default) + { + _writeParser.Consume(buffer.Span); + return base.WriteAsync(buffer, cancellationToken); + } + } +} diff --git a/src/TelemetryConsumption/TelemetryConsumptionExtensions.cs b/src/TelemetryConsumption/TelemetryConsumptionExtensions.cs index 95154289e..cc58dc529 100644 --- a/src/TelemetryConsumption/TelemetryConsumptionExtensions.cs +++ b/src/TelemetryConsumption/TelemetryConsumptionExtensions.cs @@ -11,15 +11,16 @@ public static class TelemetryConsumptionExtensions { #if NET /// - /// Registers all telemetry listeners (Proxy, Kestrel, Http, NameResolution, NetSecurity and Sockets). + /// Registers all telemetry listeners (Forwarder, Kestrel, Http, NameResolution, NetSecurity, Sockets and WebSockets). /// #else /// - /// Registers all telemetry listeners (Proxy and Kestrel). + /// Registers all telemetry listeners (Forwarder, Kestrel and WebSockets). /// #endif public static IServiceCollection AddTelemetryListeners(this IServiceCollection services) { + services.AddHostedService(); services.AddHostedService(); services.AddHostedService(); #if NET @@ -38,40 +39,46 @@ public static IServiceCollection AddTelemetryConsumer(this IServiceCollection se { var implementsAny = false; - if (consumer is IForwarderTelemetryConsumer) + if (consumer is IWebSocketsTelemetryConsumer webSocketsTelemetryConsumer) { - services.TryAddEnumerable(new ServiceDescriptor(typeof(IForwarderTelemetryConsumer), consumer)); + services.TryAddEnumerable(ServiceDescriptor.Singleton(webSocketsTelemetryConsumer)); implementsAny = true; } - if (consumer is IKestrelTelemetryConsumer) + if (consumer is IForwarderTelemetryConsumer forwarderTelemetryConsumer) { - services.TryAddEnumerable(new ServiceDescriptor(typeof(IKestrelTelemetryConsumer), consumer)); + services.TryAddEnumerable(ServiceDescriptor.Singleton(forwarderTelemetryConsumer)); + implementsAny = true; + } + + if (consumer is IKestrelTelemetryConsumer kestrelTelemetryConsumer) + { + services.TryAddEnumerable(ServiceDescriptor.Singleton(kestrelTelemetryConsumer)); implementsAny = true; } #if NET - if (consumer is IHttpTelemetryConsumer) + if (consumer is IHttpTelemetryConsumer httpTelemetryConsumer) { - services.TryAddEnumerable(new ServiceDescriptor(typeof(IHttpTelemetryConsumer), consumer)); + services.TryAddEnumerable(ServiceDescriptor.Singleton(httpTelemetryConsumer)); implementsAny = true; } - if (consumer is INameResolutionTelemetryConsumer) + if (consumer is INameResolutionTelemetryConsumer nameResolutionTelemetryConsumer) { - services.TryAddEnumerable(new ServiceDescriptor(typeof(INameResolutionTelemetryConsumer), consumer)); + services.TryAddEnumerable(ServiceDescriptor.Singleton(nameResolutionTelemetryConsumer)); implementsAny = true; } - if (consumer is INetSecurityTelemetryConsumer) + if (consumer is INetSecurityTelemetryConsumer netSecurityTelemetryConsumer) { - services.TryAddEnumerable(new ServiceDescriptor(typeof(INetSecurityTelemetryConsumer), consumer)); + services.TryAddEnumerable(ServiceDescriptor.Singleton(netSecurityTelemetryConsumer)); implementsAny = true; } - if (consumer is ISocketsTelemetryConsumer) + if (consumer is ISocketsTelemetryConsumer socketsTelemetryConsumer) { - services.TryAddEnumerable(new ServiceDescriptor(typeof(ISocketsTelemetryConsumer), consumer)); + services.TryAddEnumerable(ServiceDescriptor.Singleton(socketsTelemetryConsumer)); implementsAny = true; } #endif @@ -94,6 +101,12 @@ public static IServiceCollection AddTelemetryConsumer(this IServiceCo { var implementsAny = false; + if (typeof(IWebSocketsTelemetryConsumer).IsAssignableFrom(typeof(TConsumer))) + { + services.AddSingleton(services => (IWebSocketsTelemetryConsumer)services.GetRequiredService()); + implementsAny = true; + } + if (typeof(IForwarderTelemetryConsumer).IsAssignableFrom(typeof(TConsumer))) { services.AddSingleton(services => (IForwarderTelemetryConsumer)services.GetRequiredService()); diff --git a/src/TelemetryConsumption/WebSockets/IWebSocketsTelemetryConsumer.cs b/src/TelemetryConsumption/WebSockets/IWebSocketsTelemetryConsumer.cs new file mode 100644 index 000000000..798c5a3ea --- /dev/null +++ b/src/TelemetryConsumption/WebSockets/IWebSocketsTelemetryConsumer.cs @@ -0,0 +1,23 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System; + +namespace Yarp.Telemetry.Consumption +{ + /// + /// A consumer of Yarp.ReverseProxy.WebSockets EventSource events. + /// + public interface IWebSocketsTelemetryConsumer + { + /// + /// Called when a WebSockets connection is closed. + /// + /// Timestamp when the event was fired. + /// Timestamp when the connection upgrade completed. + /// The reason the WebSocket connection closed. + /// Messages read by the destination server. + /// Messages sent by the destination server. + void OnWebSocketClosed(DateTime timestamp, DateTime establishedTime, WebSocketCloseReason closeReason, long messagesRead, long messagesWritten); + } +} diff --git a/src/TelemetryConsumption/WebSockets/WebSocketCloseReason.cs b/src/TelemetryConsumption/WebSockets/WebSocketCloseReason.cs new file mode 100644 index 000000000..0c3d987f9 --- /dev/null +++ b/src/TelemetryConsumption/WebSockets/WebSocketCloseReason.cs @@ -0,0 +1,17 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +namespace Yarp.Telemetry.Consumption +{ + /// + /// The reason the WebSocket connection closed. + /// + public enum WebSocketCloseReason : int + { + Unknown, + ClientGracefulClose, + ServerGracefulClose, + ClientDisconnect, + ServerDisconnect, + } +} diff --git a/src/TelemetryConsumption/WebSockets/WebSocketsEventListenerService.cs b/src/TelemetryConsumption/WebSockets/WebSocketsEventListenerService.cs new file mode 100644 index 000000000..fc54c0ee9 --- /dev/null +++ b/src/TelemetryConsumption/WebSockets/WebSocketsEventListenerService.cs @@ -0,0 +1,61 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System; +using System.Collections.Generic; +using System.Collections.ObjectModel; +using System.Diagnostics; +using System.Diagnostics.Tracing; +using Microsoft.Extensions.Logging; + +namespace Yarp.Telemetry.Consumption +{ + internal interface IWebSocketsMetricsConsumer { } + + internal sealed class WebSocketsEventListenerService : EventListenerService + { + protected override string EventSourceName => "Yarp.ReverseProxy.WebSockets"; + + public WebSocketsEventListenerService(ILogger logger, IEnumerable telemetryConsumers, IEnumerable metricsConsumers) + : base(logger, telemetryConsumers, metricsConsumers) + { } + + protected override void OnEventWritten(EventWrittenEventArgs eventData) + { + const int MinEventId = 1; + const int MaxEventId = 1; + + if (eventData.EventId < MinEventId || eventData.EventId > MaxEventId) + { + return; + } + + if (TelemetryConsumers is null) + { + return; + } + +#pragma warning disable IDE0007 // Use implicit type + // Explicit type here to drop the object? signature of payload elements + ReadOnlyCollection payload = eventData.Payload!; +#pragma warning restore IDE0007 // Use implicit type + + switch (eventData.EventId) + { + case 1: + Debug.Assert(eventData.EventName == "WebSocketClosed" && payload.Count == 4); + { + var establishedTime = new DateTime((long)payload[0]); + var closeReason = (WebSocketCloseReason)payload[1]; + var messagesRead = (long)payload[2]; + var messagesWritten = (long)payload[3]; + foreach (var consumer in TelemetryConsumers) + { + consumer.OnWebSocketClosed(eventData.TimeStamp, establishedTime, closeReason, messagesRead, messagesWritten); + } + } + break; + } + } + } +} diff --git a/test/ReverseProxy.FunctionalTests/Common/TestEnvironment.cs b/test/ReverseProxy.FunctionalTests/Common/TestEnvironment.cs index 9099d5cb9..de65bfbdf 100644 --- a/test/ReverseProxy.FunctionalTests/Common/TestEnvironment.cs +++ b/test/ReverseProxy.FunctionalTests/Common/TestEnvironment.cs @@ -22,6 +22,7 @@ public class TestEnvironment { private readonly Action _configureDestinationServices; private readonly Action _configureDestinationApp; + private readonly Action _configureProxyServices; private readonly Action _configureProxy; private readonly Action _configureProxyApp; private readonly HttpProtocols _proxyProtocol; @@ -44,6 +45,7 @@ public TestEnvironment( { destinationApp.Run(destinationGetDelegate); }, + configureProxyServices: null, configureProxy, configureProxyApp, proxyProtocol, @@ -55,7 +57,7 @@ public TestEnvironment( public TestEnvironment( Action configureDestinationServices, Action configureDestinationApp, - Action configureProxy, Action configureProxyApp, + Action configureProxyServices, Action configureProxy, Action configureProxyApp, HttpProtocols proxyProtocol = HttpProtocols.Http1AndHttp2, bool useHttpsOnDestination = false, bool useHttpsOnProxy = false, Encoding headerEncoding = null, Func configTransformer = null) @@ -64,6 +66,7 @@ public TestEnvironment( _configureDestinationApp = configureDestinationApp; _configureProxy = configureProxy; _configureProxyApp = configureProxyApp; + _configureProxyServices = configureProxyServices ?? (_ => { }); _proxyProtocol = proxyProtocol; _useHttpsOnDestination = useHttpsOnDestination; _useHttpsOnProxy = useHttpsOnProxy; @@ -76,7 +79,7 @@ public async Task Invoke(Func clientFunc, CancellationToken cancel using var destination = CreateHost(HttpProtocols.Http1AndHttp2, _useHttpsOnDestination, _headerEncoding, _configureDestinationServices, _configureDestinationApp); await destination.StartAsync(cancellationToken); - using var proxy = CreateProxy(_proxyProtocol, _useHttpsOnDestination, _useHttpsOnProxy, _headerEncoding, ClusterId, destination.GetAddress(), _configureProxy, _configureProxyApp, _configTransformer); + using var proxy = CreateProxy(_proxyProtocol, _useHttpsOnDestination, _useHttpsOnProxy, _headerEncoding, ClusterId, destination.GetAddress(), _configureProxyServices, _configureProxy, _configureProxyApp, _configTransformer); await proxy.StartAsync(cancellationToken); try @@ -91,11 +94,13 @@ public async Task Invoke(Func clientFunc, CancellationToken cancel } public static IHost CreateProxy(HttpProtocols protocols, bool useHttpsOnDestination, bool httpsOnProxy, Encoding requestHeaderEncoding, string clusterId, string destinationAddress, - Action configureProxy, Action configureProxyApp, Func configTransformer) + Action configureServices, Action configureProxy, Action configureProxyApp, Func configTransformer) { return CreateHost(protocols, httpsOnProxy, requestHeaderEncoding, services => { + configureServices(services); + var route = new RouteConfig { RouteId = "route1", diff --git a/test/ReverseProxy.FunctionalTests/HeaderTests.cs b/test/ReverseProxy.FunctionalTests/HeaderTests.cs index 1fe588b31..3981c380a 100644 --- a/test/ReverseProxy.FunctionalTests/HeaderTests.cs +++ b/test/ReverseProxy.FunctionalTests/HeaderTests.cs @@ -340,6 +340,7 @@ public async Task ProxyAsync_ResponseWithEncodedHeaderValue(string headerValue, Exception unhandledError = null; using var proxy = TestEnvironment.CreateProxy(HttpProtocols.Http1, false, false, encoding, "cluster1", $"http://{tcpListener.LocalEndpoint}", + proxyServices => { }, proxyBuilder => { }, proxyApp => { diff --git a/test/ReverseProxy.FunctionalTests/WebSocketTests.cs b/test/ReverseProxy.FunctionalTests/WebSocketTests.cs index 7983f310c..373cf56b0 100644 --- a/test/ReverseProxy.FunctionalTests/WebSocketTests.cs +++ b/test/ReverseProxy.FunctionalTests/WebSocketTests.cs @@ -165,6 +165,7 @@ private static TestEnvironment CreateTestEnvironment(bool forceUpgradable = fals builder.Map("/post", Post); }); }, + proxyServices => { }, proxyBuilder => { }, proxyApp => { diff --git a/test/ReverseProxy.FunctionalTests/WebSocketsTelemetryTests.cs b/test/ReverseProxy.FunctionalTests/WebSocketsTelemetryTests.cs new file mode 100644 index 000000000..333a5c9d5 --- /dev/null +++ b/test/ReverseProxy.FunctionalTests/WebSocketsTelemetryTests.cs @@ -0,0 +1,361 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +#nullable enable + +using System; +using System.Net.Http; +using System.Net.WebSockets; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Builder; +using Microsoft.AspNetCore.Http; +using Microsoft.Extensions.DependencyInjection; +using Xunit; +using Yarp.ReverseProxy.Common; +using Yarp.ReverseProxy.Utilities; +using Yarp.Telemetry.Consumption; +using Yarp.Tests.Common; + +namespace Yarp.ReverseProxy +{ + public class WebSocketsTelemetryTests + { + [Fact] + public async Task NoWebSocketsUpgrade_NoTelemetryWritten() + { + var telemetry = await TestAsync( + async uri => + { + using var client = new HttpClient(); + await client.GetStringAsync(uri); + }, + (context, webSocket) => throw new InvalidOperationException("Shouldn't be reached")); + + Assert.Null(telemetry); + } + + [Theory] + [InlineData(0, 0, 42)] + [InlineData(0, 1, 42)] + [InlineData(1, 0, 42)] + [InlineData(23, 29, 0)] + [InlineData(17, 19, 1)] + [InlineData(11, 13, 100)] + [InlineData(5, 7, 1_000)] + [InlineData(2, 3, 100_000)] + public async Task MessagesExchanged_CorrectNumberReported(int read, int written, int messageSize) + { + var telemetry = await TestAsync( + async uri => + { + using var client = new ClientWebSocket(); + await client.ConnectAsync(uri, CancellationToken.None); + var webSocket = new WebSocketAdapter(client); + + await Task.WhenAll( + SendMessagesAndCloseAsync(webSocket, read, messageSize), + ReceiveAllMessagesAsync(webSocket)); + }, + async (context, webSocket) => + { + await Task.WhenAll( + SendMessagesAndCloseAsync(webSocket, written, messageSize), + ReceiveAllMessagesAsync(webSocket)); + }, + new ManualClock(new TimeSpan(42))); + + Assert.NotNull(telemetry); + Assert.Equal(42, telemetry!.EstablishedTime.Ticks); + Assert.Contains(telemetry.CloseReason, new[] { WebSocketCloseReason.ClientGracefulClose, WebSocketCloseReason.ServerGracefulClose }); + Assert.Equal(read, telemetry!.MessagesRead); + Assert.Equal(written, telemetry.MessagesWritten); + } + + public enum Behavior + { + ClosesConnection = 1, + SendsClose_WaitsForClose = 2, + SendsClose_ClosesConnection = 4 | ClosesConnection, + WaitsForClose_SendsClose = 8, + WaitsForClose_ClosesConnection = 16 | ClosesConnection, + } + + [Theory] + // Both sides close the connection - race between which is noticed first + [InlineData(Behavior.ClosesConnection, Behavior.ClosesConnection, WebSocketCloseReason.Unknown, WebSocketCloseReason.ClientDisconnect, WebSocketCloseReason.ServerDisconnect)] + // One side sends a graceful close + [InlineData(Behavior.SendsClose_ClosesConnection, Behavior.WaitsForClose_ClosesConnection, WebSocketCloseReason.ClientGracefulClose)] + [InlineData(Behavior.SendsClose_WaitsForClose, Behavior.WaitsForClose_ClosesConnection, WebSocketCloseReason.ClientGracefulClose)] + [InlineData(Behavior.WaitsForClose_ClosesConnection, Behavior.SendsClose_ClosesConnection, WebSocketCloseReason.ServerGracefulClose)] + [InlineData(Behavior.WaitsForClose_ClosesConnection, Behavior.SendsClose_WaitsForClose, WebSocketCloseReason.ServerGracefulClose)] + // One side sends a graceful close while the other disconnects - race between which is noticed first + [InlineData(Behavior.SendsClose_WaitsForClose, Behavior.ClosesConnection, WebSocketCloseReason.ClientGracefulClose, WebSocketCloseReason.ServerDisconnect)] + [InlineData(Behavior.SendsClose_ClosesConnection, Behavior.ClosesConnection, WebSocketCloseReason.ClientGracefulClose, WebSocketCloseReason.ServerDisconnect)] + [InlineData(Behavior.ClosesConnection, Behavior.SendsClose_ClosesConnection, WebSocketCloseReason.ServerGracefulClose, WebSocketCloseReason.ClientDisconnect)] + [InlineData(Behavior.ClosesConnection, Behavior.SendsClose_WaitsForClose, WebSocketCloseReason.ServerGracefulClose, WebSocketCloseReason.ClientDisconnect)] + // One side closes the connection while the other is waiting for messages + [InlineData(Behavior.ClosesConnection, Behavior.WaitsForClose_SendsClose, WebSocketCloseReason.ClientDisconnect)] + [InlineData(Behavior.ClosesConnection, Behavior.WaitsForClose_ClosesConnection, WebSocketCloseReason.ClientDisconnect)] + [InlineData(Behavior.WaitsForClose_SendsClose, Behavior.ClosesConnection, WebSocketCloseReason.ServerDisconnect)] + [InlineData(Behavior.WaitsForClose_ClosesConnection, Behavior.ClosesConnection, WebSocketCloseReason.ServerDisconnect)] + // Graceful, mutual close - other side closes as a reaction to receiving close + [InlineData(Behavior.SendsClose_WaitsForClose, Behavior.WaitsForClose_SendsClose, WebSocketCloseReason.ClientGracefulClose)] + [InlineData(Behavior.SendsClose_ClosesConnection, Behavior.WaitsForClose_SendsClose, WebSocketCloseReason.ClientGracefulClose)] + [InlineData(Behavior.WaitsForClose_SendsClose, Behavior.SendsClose_WaitsForClose, WebSocketCloseReason.ServerGracefulClose)] + [InlineData(Behavior.WaitsForClose_SendsClose, Behavior.SendsClose_ClosesConnection, WebSocketCloseReason.ServerGracefulClose)] + // Graceful, mutual close - both sides close at the same time - race between which is noticed first + [InlineData(Behavior.SendsClose_WaitsForClose, Behavior.SendsClose_WaitsForClose, WebSocketCloseReason.ClientGracefulClose, WebSocketCloseReason.ServerGracefulClose)] + [InlineData(Behavior.SendsClose_WaitsForClose, Behavior.SendsClose_ClosesConnection, WebSocketCloseReason.ClientGracefulClose, WebSocketCloseReason.ServerGracefulClose)] + [InlineData(Behavior.SendsClose_ClosesConnection, Behavior.SendsClose_WaitsForClose, WebSocketCloseReason.ClientGracefulClose, WebSocketCloseReason.ServerGracefulClose)] + [InlineData(Behavior.SendsClose_ClosesConnection, Behavior.SendsClose_ClosesConnection, WebSocketCloseReason.ClientGracefulClose, WebSocketCloseReason.ServerGracefulClose)] + public async Task ConnectionClosed_BlameAttributedCorrectly(Behavior clientBehavior, Behavior serverBehavior, params WebSocketCloseReason[] expectedReasons) + { + var telemetry = await TestAsync( + async uri => + { + using var client = new ClientWebSocket(); + + // Keep sending messages from the client in order to observe a server disconnect sooner + client.Options.KeepAliveInterval = TimeSpan.FromMilliseconds(10); + + await client.ConnectAsync(uri, CancellationToken.None); + var webSocket = new WebSocketAdapter(client); + + try + { + await ProcessAsync(webSocket, clientBehavior, client: client); + } + catch + { + Assert.True(serverBehavior.HasFlag(Behavior.ClosesConnection)); + } + }, + async (context, webSocket) => + { + try + { + await ProcessAsync(webSocket, serverBehavior, context: context); + } + catch + { + Assert.True(clientBehavior.HasFlag(Behavior.ClosesConnection)); + } + }); + + Assert.NotNull(telemetry); + Assert.Contains(telemetry!.CloseReason, expectedReasons); + + static async Task ProcessAsync(WebSocketAdapter webSocket, Behavior behavior, ClientWebSocket? client = null, HttpContext? context = null) + { + if (behavior == Behavior.SendsClose_WaitsForClose || + behavior == Behavior.SendsClose_ClosesConnection) + { + await webSocket.CloseOutputAsync(WebSocketCloseStatus.NormalClosure, "Bye"); + } + + if (behavior == Behavior.SendsClose_WaitsForClose || + behavior == Behavior.WaitsForClose_SendsClose || + behavior == Behavior.WaitsForClose_ClosesConnection) + { + await ReceiveAllMessagesAsync(webSocket); + } + + if (behavior == Behavior.WaitsForClose_SendsClose) + { + await webSocket.CloseOutputAsync(WebSocketCloseStatus.NormalClosure, "Bye"); + } + + if (behavior.HasFlag(Behavior.ClosesConnection)) + { + client?.Abort(); + + if (context is not null) + { + await context.Response.Body.FlushAsync(); + context.Abort(); + } + } + } + } + + [Theory] + [InlineData(100, 200, WebSocketCloseReason.ClientGracefulClose)] + [InlineData(200, 100, WebSocketCloseReason.ServerGracefulClose)] + [InlineData(100, 100, WebSocketCloseReason.ServerGracefulClose)] // Implementation detail + public async Task ConnectionClosed_BlameReliesOnCloseTimes(long clientCloseTime, long serverCloseTime, WebSocketCloseReason expectedCloseReason) + { + var clock = new ManualClock(new TimeSpan(1)); + + var telemetry = await TestAsync( + async uri => + { + using var client = new ClientWebSocket(); + await client.ConnectAsync(uri, CancellationToken.None); + var webSocket = new WebSocketAdapter(client); + + await ProcessAsync(webSocket, clock, clientCloseTime, sendCloseFirst: clientCloseTime <= serverCloseTime); + }, + async (context, webSocket) => + { + await ProcessAsync(webSocket, clock, serverCloseTime, sendCloseFirst: serverCloseTime < clientCloseTime); + }, + clock); + + Assert.NotNull(telemetry); + Assert.Equal(1, telemetry!.EstablishedTime.Ticks); + Assert.Equal(expectedCloseReason, telemetry.CloseReason); + + static async Task ProcessAsync(WebSocketAdapter webSocket, ManualClock clock, long closeTime, bool sendCloseFirst) + { + var receiveTask = ReceiveAllMessagesAsync(webSocket); + + if (!sendCloseFirst) + { + await receiveTask; + } + + lock (clock) + { + clock.AdvanceClockTo(TimeSpan.FromTicks(closeTime)); + } + + await webSocket.CloseOutputAsync(WebSocketCloseStatus.NormalClosure, "Bye", CancellationToken.None); + + await receiveTask; + } + } + + private static async Task ReceiveAllMessagesAsync(WebSocketAdapter webSocket) + { + Memory buffer = new byte[1024]; + + while (true) + { + var result = await webSocket.ReceiveAsync(buffer); + + if (result.MessageType == WebSocketMessageType.Close) + { + break; + } + } + } + + private static async Task SendMessagesAndCloseAsync(WebSocketAdapter webSocket, int messageCount, int messageSize) + { + var rng = new Random(42); + var buffer = new byte[1024]; + + for (var i = 0; i < messageCount; i++) + { + var remaining = messageSize; + + while (remaining > 1) + { + var chunkSize = Math.Min(buffer.Length, remaining - 1); + remaining -= chunkSize; + var chunk = buffer.AsMemory(0, chunkSize); + rng.NextBytes(chunk.Span); + await webSocket.SendAsync(chunk, WebSocketMessageType.Binary, endOfMessage: false); + } + + await webSocket.SendAsync(buffer.AsMemory(0, remaining), WebSocketMessageType.Binary, endOfMessage: true); + } + + await webSocket.CloseOutputAsync(WebSocketCloseStatus.NormalClosure, "Bye", CancellationToken.None); + } + + private class WebSocketAdapter + { + private readonly ClientWebSocket? _client; + private readonly WebSocket? _server; + + public WebSocketAdapter(ClientWebSocket? client = null, WebSocket? server = null) + { + Assert.True(client is null ^ server is null); + _client = client; + _server = server; + } + + public ValueTask ReceiveAsync(Memory buffer, CancellationToken cancellationToken = default) + { + return _client is not null + ? _client.ReceiveAsync(buffer, cancellationToken) + : _server!.ReceiveAsync(buffer, cancellationToken); + } + + public ValueTask SendAsync(ReadOnlyMemory buffer, WebSocketMessageType messageType, bool endOfMessage, CancellationToken cancellationToken = default) + { + return _client is not null + ? _client.SendAsync(buffer, messageType, endOfMessage, cancellationToken) + : _server!.SendAsync(buffer, messageType, endOfMessage, cancellationToken); + } + + public Task CloseOutputAsync(WebSocketCloseStatus closeStatus, string? statusDescription, CancellationToken cancellationToken = default) + { + return _client is not null + ? _client.CloseOutputAsync(closeStatus, statusDescription, cancellationToken) + : _server!.CloseOutputAsync(closeStatus, statusDescription, cancellationToken); + } + } + + private static async Task TestAsync(Func requestDelegate, Func destinationDelegate, IClock? clock = null) + { + var telemetryConsumer = new TelemetryConsumer(); + + var test = new TestEnvironment( + destinationServies => { }, + destinationApp => + { + destinationApp.UseWebSockets(); + + destinationApp.Run(async context => + { + if (context.WebSockets.IsWebSocketRequest) + { + var webSocket = await context.WebSockets.AcceptWebSocketAsync(); + + await destinationDelegate(context, new WebSocketAdapter(server: webSocket)); + } + }); + }, + proxyServices => + { + if (clock is not null) + { + proxyServices.AddSingleton(clock); + } + }, + proxyBuilder => + { + proxyBuilder.Services.AddTelemetryConsumer(telemetryConsumer); + }, + proxyApp => + { + proxyApp.UseWebSocketsTelemetry(); + }); + + await test.Invoke(async uri => + { + var webSocketsTarget = uri.Replace("https://", "wss://").Replace("http://", "ws://"); + var webSocketsUri = new Uri(webSocketsTarget, UriKind.Absolute); + + await requestDelegate(webSocketsUri); + }); + + return telemetryConsumer.Telemetry; + } + + private record WebSocketsTelemetry(DateTime Timestamp, DateTime EstablishedTime, WebSocketCloseReason CloseReason, long MessagesRead, long MessagesWritten); + + private class TelemetryConsumer : IWebSocketsTelemetryConsumer + { + public WebSocketsTelemetry? Telemetry { get; private set; } + + public void OnWebSocketClosed(DateTime timestamp, DateTime establishedTime, WebSocketCloseReason closeReason, long messagesRead, long messagesWritten) + { + Telemetry = new WebSocketsTelemetry(timestamp, establishedTime, closeReason, messagesRead, messagesWritten); + } + } + } +} diff --git a/test/ReverseProxy.Tests/Forwarder/HttpForwarderTests.cs b/test/ReverseProxy.Tests/Forwarder/HttpForwarderTests.cs index 8836f95f9..bdd330174 100644 --- a/test/ReverseProxy.Tests/Forwarder/HttpForwarderTests.cs +++ b/test/ReverseProxy.Tests/Forwarder/HttpForwarderTests.cs @@ -19,6 +19,7 @@ using Moq; using Xunit; using Yarp.Tests.Common; +using Yarp.ReverseProxy.Utilities; namespace Yarp.ReverseProxy.Forwarder.Tests { diff --git a/test/ReverseProxy.Tests/WebSocketsTelemetry/WebSocketsParserTests.cs b/test/ReverseProxy.Tests/WebSocketsTelemetry/WebSocketsParserTests.cs new file mode 100644 index 000000000..eed635a6f --- /dev/null +++ b/test/ReverseProxy.Tests/WebSocketsTelemetry/WebSocketsParserTests.cs @@ -0,0 +1,240 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System; +using System.IO; +using System.Text; +using Xunit; +using Yarp.ReverseProxy.Utilities; +using Yarp.Tests.Common; + +namespace Yarp.ReverseProxy.WebSocketsTelemetry.Tests +{ + public abstract class WebSocketsParserTests + { + protected abstract bool IsServer { get; } + + private int MaskSize => IsServer ? 4 : 0; + + private WebSocketsParser CreateParser(IClock clock = null) => new(clock ?? new Clock(), IsServer); + + private ReadOnlySpan GetHeader(int opcode, int length, bool endOfMessage = true) + { + var header = new byte[2 + MaskSize + (length < 126 ? 0 : (length < 65536 ? 2 : 8))]; + + Assert.InRange(opcode, 0, 15); + header[0] = (byte)opcode; + + if (endOfMessage) + { + header[0] |= 0x80; + } + + if (length < 126) + { + header[1] = (byte)length; + } + else + { + header[1] = (byte)(length < 65536 ? 126 : 127); + var i = header.Length - MaskSize - 1; + while (length != 0) + { + header[i--] = (byte)(length % 256); + length /= 256; + } + } + + if (IsServer) + { + header[1] |= 0x80; + } + + return header; + } + + private ReadOnlySpan GetCloseFrame(int length = 0) => GetBinaryMessageFrame(Encoding.UTF8.GetBytes(new string('a', length)), opcode: 8); + + private ReadOnlySpan GetPingFrame(int length = 0) => GetBinaryMessageFrame(Encoding.UTF8.GetBytes(new string('a', length)), opcode: 9); + + private ReadOnlySpan GetPongFrame(int length = 0) => GetBinaryMessageFrame(Encoding.UTF8.GetBytes(new string('a', length)), opcode: 10); + + private ReadOnlySpan GetTextMessageFrame(string message, bool continuation = false, bool endOfMessage = true) + { + var messageBytes = Encoding.UTF8.GetBytes(message); + var header = GetHeader(opcode: continuation ? 0 : 1, length: messageBytes.Length, endOfMessage); + + var frame = new byte[header.Length + messageBytes.Length]; + header.CopyTo(frame); + messageBytes.CopyTo(frame, header.Length); + + return frame; + } + + private ReadOnlySpan GetBinaryMessageFrame(ReadOnlySpan message, bool continuation = false, bool endOfMessage = true, int opcode = 2) + { + var header = GetHeader(opcode: continuation ? 0 : opcode, length: message.Length, endOfMessage); + + var frame = new byte[header.Length + message.Length]; + header.CopyTo(frame); + message.CopyTo(frame.AsSpan(header.Length)); + + return frame; + } + + [Fact] + public void CustomClockIsUsedForCloseTime() + { + var clock = new ManualClock(new TimeSpan(42)); + var parser = CreateParser(clock); + + Assert.Null(parser.CloseTime); + + parser.Consume(GetCloseFrame()); + + Assert.NotNull(parser.CloseTime); + Assert.Equal(clock.GetUtcNow(), parser.CloseTime.Value); + } + + [Fact] + public void MessagesAreCountedCorrectly() + { + var parser = CreateParser(); + + // Whole messages + parser.Consume(GetTextMessageFrame("Foo")); + Assert.Equal(1, parser.MessageCount); + + parser.Consume(GetBinaryMessageFrame(new byte[] { 4, 2 })); + Assert.Equal(2, parser.MessageCount); + + + // Continuations + parser.Consume(GetTextMessageFrame("Hello, ", endOfMessage: false)); + Assert.Equal(2, parser.MessageCount); + + parser.Consume(GetTextMessageFrame("world", continuation: true, endOfMessage: false)); + Assert.Equal(2, parser.MessageCount); + + parser.Consume(GetTextMessageFrame("!", continuation: true, endOfMessage: true)); + Assert.Equal(3, parser.MessageCount); + + parser.Consume(GetBinaryMessageFrame(new byte[] { 4 }, endOfMessage: false)); + Assert.Equal(3, parser.MessageCount); + + parser.Consume(GetBinaryMessageFrame(new byte[] { 2 }, continuation: true, endOfMessage: true)); + Assert.Equal(4, parser.MessageCount); + + + // Large messages + parser.Consume(GetTextMessageFrame(new string('a', 1_000))); + Assert.Equal(5, parser.MessageCount); + + parser.Consume(GetTextMessageFrame(new string('b', 100_000))); + Assert.Equal(6, parser.MessageCount); + + parser.Consume(GetBinaryMessageFrame(Encoding.UTF8.GetBytes(new string('c', 1_000)))); + Assert.Equal(7, parser.MessageCount); + + parser.Consume(GetBinaryMessageFrame(Encoding.UTF8.GetBytes(new string('d', 100_000)))); + Assert.Equal(8, parser.MessageCount); + + + // Large messages with continuations + parser.Consume(GetTextMessageFrame(new string('a', 1_000), endOfMessage: false)); + Assert.Equal(8, parser.MessageCount); + + parser.Consume(GetTextMessageFrame(new string('b', 1_000), continuation: true, endOfMessage: true)); + Assert.Equal(9, parser.MessageCount); + + parser.Consume(GetBinaryMessageFrame(Encoding.UTF8.GetBytes(new string('c', 1_000)), endOfMessage: false)); + Assert.Equal(9, parser.MessageCount); + + parser.Consume(GetBinaryMessageFrame(Encoding.UTF8.GetBytes(new string('d', 1_000)), continuation: true, endOfMessage: true)); + Assert.Equal(10, parser.MessageCount); + + + // Fragmented frames + parser.Consume(Array.Empty()); + Assert.Equal(10, parser.MessageCount); + + ConsumeInFragments(ref parser, GetBinaryMessageFrame(Encoding.UTF8.GetBytes(new string('a', 1_000)))); + Assert.Equal(11, parser.MessageCount); + + var ms = new MemoryStream(); + for (var i = (int)parser.MessageCount; i < 500; i++) + { + // Control frames are not counted + if (i % 7 == 0) + { + ms.Write(GetPingFrame()); + } + if (i % 13 == 0) + { + ms.Write(GetPongFrame()); + } + + switch (i % 4) + { + case 0: + ms.Write(GetTextMessageFrame(new string('a', i))); + break; + + case 1: + ms.Write(GetBinaryMessageFrame(Encoding.UTF8.GetBytes(new string('b', i)))); + break; + + case 2: + ms.Write(GetTextMessageFrame(new string('a', i), endOfMessage: false)); + ms.Write(GetTextMessageFrame(new string('b', i), continuation: true, endOfMessage: false)); + ms.Write(GetTextMessageFrame(new string('c', i), continuation: true, endOfMessage: true)); + break; + + case 3: + ms.Write(GetBinaryMessageFrame(Encoding.UTF8.GetBytes(new string('a', i)), endOfMessage: false)); + ms.Write(GetBinaryMessageFrame(Encoding.UTF8.GetBytes(new string('b', i)), continuation: true, endOfMessage: false)); + ms.Write(GetBinaryMessageFrame(Encoding.UTF8.GetBytes(new string('c', i)), continuation: true, endOfMessage: true)); + break; + } + } + ConsumeInFragments(ref parser, ms.ToArray()); + Assert.Equal(500, parser.MessageCount); + + + // Control frames are not counted + parser.Consume(GetPingFrame()); + parser.Consume(GetPingFrame(length: 10)); + parser.Consume(GetPongFrame()); + parser.Consume(GetPongFrame(length: 10)); + parser.Consume(GetCloseFrame()); + parser.Consume(GetCloseFrame(length: 10)); + Assert.Equal(500, parser.MessageCount); + + + // Messages are still counted after a close frame + parser.Consume(GetTextMessageFrame("Foo")); + Assert.Equal(501, parser.MessageCount); + + static void ConsumeInFragments(ref WebSocketsParser parser, ReadOnlySpan message) + { + var rng = new Random(42); + while (message.Length != 0) + { + var fragmentLength = Math.Min(message.Length, rng.Next(0, 150)); + parser.Consume(message.Slice(0, fragmentLength)); + message = message.Slice(fragmentLength); + } + } + } + } + + public sealed class WebSocketsParserTests_Client : WebSocketsParserTests + { + protected override bool IsServer => false; + } + + public sealed class WebSocketsParserTests_Server : WebSocketsParserTests + { + protected override bool IsServer => true; + } +} diff --git a/test/Tests.Common/ManualClock.cs b/test/Tests.Common/ManualClock.cs index 6be55148b..001066c6a 100644 --- a/test/Tests.Common/ManualClock.cs +++ b/test/Tests.Common/ManualClock.cs @@ -29,7 +29,7 @@ public class ManualClock : IClock private TimeSpan _currentTime; /// - /// Initializes a new instance of the class. + /// Initializes a new instance of the class. /// /// Initial value for current time. Zero if not specified. public ManualClock(TimeSpan? initialTime = null) @@ -69,7 +69,7 @@ public void AdvanceClockTo(TimeSpan targetTime) _currentTime = targetTime; } - public DateTimeOffset GetUtcNow() => DateTimeOffset.UtcNow; + public DateTimeOffset GetUtcNow() => new DateTime(_currentTime.Ticks, DateTimeKind.Utc); public TimeSpan GetStopwatchTime() => _currentTime;