diff --git a/src/Servers/Kestrel/Core/src/Internal/Http2/Http2Connection.cs b/src/Servers/Kestrel/Core/src/Internal/Http2/Http2Connection.cs index 68eaf98be9b8..feaff535ee4e 100644 --- a/src/Servers/Kestrel/Core/src/Internal/Http2/Http2Connection.cs +++ b/src/Servers/Kestrel/Core/src/Internal/Http2/Http2Connection.cs @@ -71,7 +71,7 @@ internal partial class Http2Connection : IHttp2StreamLifetimeHandler, IHttpStrea internal readonly Http2KeepAlive? _keepAlive; internal readonly Dictionary _streams = new Dictionary(); internal PooledStreamStack StreamPool; - internal Action? _onStreamCompleted; + internal IHttp2StreamLifetimeHandler _streamLifetimeHandler; public Http2Connection(HttpConnectionContext context) { @@ -79,6 +79,7 @@ public Http2Connection(HttpConnectionContext context) var http2Limits = httpLimits.Http2; _context = context; + _streamLifetimeHandler = this; // Capture the ExecutionContext before dispatching HTTP/2 middleware. Will be restored by streams when processing request _context.InitialExecutionContext = ExecutionContext.Capture(); @@ -753,7 +754,7 @@ private Http2StreamContext CreateHttp2StreamContext() _context.LocalEndPoint, _context.RemoteEndPoint, _incomingFrame.StreamId, - streamLifetimeHandler: this, + _streamLifetimeHandler, _clientSettings, _serverSettings, _frameWriter, @@ -1230,7 +1231,6 @@ void IHttp2StreamLifetimeHandler.OnStreamCompleted(Http2Stream stream) { _completedStreams.Enqueue(stream); _streamCompletionAwaitable.Complete(); - _onStreamCompleted?.Invoke(stream); } private void UpdateCompletedStreams() diff --git a/src/Servers/Kestrel/test/InMemory.FunctionalTests/Http2/Http2ConnectionTests.cs b/src/Servers/Kestrel/test/InMemory.FunctionalTests/Http2/Http2ConnectionTests.cs index cf1b6fdc7f40..ced453867fe7 100644 --- a/src/Servers/Kestrel/test/InMemory.FunctionalTests/Http2/Http2ConnectionTests.cs +++ b/src/Servers/Kestrel/test/InMemory.FunctionalTests/Http2/Http2ConnectionTests.cs @@ -564,16 +564,13 @@ await InitializeConnectionAsync(async context => throw new InvalidOperationException("Put the stream into an invalid state by throwing after writing to response."); }); - var streamCompletedTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); - _connection._onStreamCompleted = _ => streamCompletedTcs.TrySetResult(); - await StartStreamAsync(1, _browserRequestHeaders, endStream: true); var stream = _connection._streams[1]; serverTcs.SetResult(); // Wait for the stream to be completed - await streamCompletedTcs.Task; + await WaitForStreamAsync(stream.StreamId).DefaultTimeout(); // TriggerTick will trigger the stream to be returned to the pool so we can assert it TriggerTick(); diff --git a/src/Servers/Kestrel/test/InMemory.FunctionalTests/Http2/Http2TestBase.cs b/src/Servers/Kestrel/test/InMemory.FunctionalTests/Http2/Http2TestBase.cs index b7d6b7a47535..cd2150c83b5c 100644 --- a/src/Servers/Kestrel/test/InMemory.FunctionalTests/Http2/Http2TestBase.cs +++ b/src/Servers/Kestrel/test/InMemory.FunctionalTests/Http2/Http2TestBase.cs @@ -461,6 +461,7 @@ protected void CreateConnection() timeoutControl: _mockTimeoutControl.Object); _connection = new Http2Connection(httpConnectionContext); + _connection._streamLifetimeHandler = new LifetimeHandlerInterceptor(_connection._streamLifetimeHandler, this); var httpConnection = new HttpConnection(httpConnectionContext); httpConnection.Initialize(_connection); @@ -470,6 +471,35 @@ protected void CreateConnection() _timeoutControl.Initialize(_serviceContext.SystemClock.UtcNow.Ticks); } + private class LifetimeHandlerInterceptor : IHttp2StreamLifetimeHandler + { + private readonly IHttp2StreamLifetimeHandler _inner; + private readonly Http2TestBase _httpTestBase; + + public LifetimeHandlerInterceptor(IHttp2StreamLifetimeHandler inner, Http2TestBase httpTestBase) + { + _inner = inner; + _httpTestBase = httpTestBase; + } + + public void DecrementActiveClientStreamCount() + { + _inner.DecrementActiveClientStreamCount(); + } + + public void OnStreamCompleted(Http2Stream stream) + { + _inner.OnStreamCompleted(stream); + + // Stream in test might not have been started with StartStream method. + // In that case there isn't a record of a running stream. + if (_httpTestBase._runningStreams.TryGetValue(stream.StreamId, out var tcs)) + { + tcs.TrySetResult(); + } + } + } + protected void InitializeConnectionWithoutPreface(RequestDelegate application) { if (_connection == null) @@ -681,6 +711,11 @@ protected Task SendHeadersWithPaddingAndPriorityAsync(int streamId, IEnumerable< return FlushAsync(writableBuffer); } + protected Task WaitForStreamAsync(int streamId) + { + return _runningStreams[streamId].Task; + } + protected Task WaitForAllStreamsAsync() { return Task.WhenAll(_runningStreams.Values.Select(tcs => tcs.Task)).DefaultTimeout();