Skip to content

Raise HTTP/2 stream end event in tests with lifetime handler #41230

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 18, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -71,14 +71,15 @@ internal partial class Http2Connection : IHttp2StreamLifetimeHandler, IHttpStrea
internal readonly Http2KeepAlive? _keepAlive;
internal readonly Dictionary<int, Http2Stream> _streams = new Dictionary<int, Http2Stream>();
internal PooledStreamStack<Http2Stream> StreamPool;
internal Action<Http2Stream>? _onStreamCompleted;
internal IHttp2StreamLifetimeHandler _streamLifetimeHandler;

public Http2Connection(HttpConnectionContext context)
{
var httpLimits = context.ServiceContext.ServerOptions.Limits;
var http2Limits = httpLimits.Http2;

_context = context;
_streamLifetimeHandler = this;

// Capture the ExecutionContext before dispatching HTTP/2 middleware. Will be restored by streams when processing request
_context.InitialExecutionContext = ExecutionContext.Capture();
Expand Down Expand Up @@ -753,7 +754,7 @@ private Http2StreamContext CreateHttp2StreamContext()
_context.LocalEndPoint,
_context.RemoteEndPoint,
_incomingFrame.StreamId,
streamLifetimeHandler: this,
_streamLifetimeHandler,
_clientSettings,
_serverSettings,
_frameWriter,
Expand Down Expand Up @@ -1230,7 +1231,6 @@ void IHttp2StreamLifetimeHandler.OnStreamCompleted(Http2Stream stream)
{
_completedStreams.Enqueue(stream);
_streamCompletionAwaitable.Complete();
_onStreamCompleted?.Invoke(stream);
}

private void UpdateCompletedStreams()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -564,16 +564,13 @@ await InitializeConnectionAsync(async context =>
throw new InvalidOperationException("Put the stream into an invalid state by throwing after writing to response.");
});

var streamCompletedTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
_connection._onStreamCompleted = _ => streamCompletedTcs.TrySetResult();

await StartStreamAsync(1, _browserRequestHeaders, endStream: true);

var stream = _connection._streams[1];
serverTcs.SetResult();

// Wait for the stream to be completed
await streamCompletedTcs.Task;
await WaitForStreamAsync(stream.StreamId).DefaultTimeout();

// TriggerTick will trigger the stream to be returned to the pool so we can assert it
TriggerTick();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -461,6 +461,7 @@ protected void CreateConnection()
timeoutControl: _mockTimeoutControl.Object);

_connection = new Http2Connection(httpConnectionContext);
_connection._streamLifetimeHandler = new LifetimeHandlerInterceptor(_connection._streamLifetimeHandler, this);

var httpConnection = new HttpConnection(httpConnectionContext);
httpConnection.Initialize(_connection);
Expand All @@ -470,6 +471,35 @@ protected void CreateConnection()
_timeoutControl.Initialize(_serviceContext.SystemClock.UtcNow.Ticks);
}

private class LifetimeHandlerInterceptor : IHttp2StreamLifetimeHandler
{
private readonly IHttp2StreamLifetimeHandler _inner;
private readonly Http2TestBase _httpTestBase;

public LifetimeHandlerInterceptor(IHttp2StreamLifetimeHandler inner, Http2TestBase httpTestBase)
{
_inner = inner;
_httpTestBase = httpTestBase;
}

public void DecrementActiveClientStreamCount()
{
_inner.DecrementActiveClientStreamCount();
}

public void OnStreamCompleted(Http2Stream stream)
{
_inner.OnStreamCompleted(stream);

// Stream in test might not have been started with StartStream method.
// In that case there isn't a record of a running stream.
if (_httpTestBase._runningStreams.TryGetValue(stream.StreamId, out var tcs))
{
tcs.TrySetResult();
}
}
}

protected void InitializeConnectionWithoutPreface(RequestDelegate application)
{
if (_connection == null)
Expand Down Expand Up @@ -681,6 +711,11 @@ protected Task SendHeadersWithPaddingAndPriorityAsync(int streamId, IEnumerable<
return FlushAsync(writableBuffer);
}

protected Task WaitForStreamAsync(int streamId)
{
return _runningStreams[streamId].Task;
}

protected Task WaitForAllStreamsAsync()
{
return Task.WhenAll(_runningStreams.Values.Select(tcs => tcs.Task)).DefaultTimeout();
Expand Down