diff --git a/src/Servers/Kestrel/Core/src/Internal/Http/Http1Connection.cs b/src/Servers/Kestrel/Core/src/Internal/Http/Http1Connection.cs index 649560ce78d3..dca44693a026 100644 --- a/src/Servers/Kestrel/Core/src/Internal/Http/Http1Connection.cs +++ b/src/Servers/Kestrel/Core/src/Internal/Http/Http1Connection.cs @@ -37,8 +37,9 @@ public partial class Http1Connection : HttpProtocol, IRequestProcessor private int _remainingRequestHeadersBytesAllowed; public Http1Connection(HttpConnectionContext context) - : base(context) { + Initialize(context); + _context = context; _parser = ServiceContext.HttpParser; _keepAliveTicks = ServerOptions.Limits.KeepAliveTimeout.Ticks; diff --git a/src/Servers/Kestrel/Core/src/Internal/Http/HttpProtocol.cs b/src/Servers/Kestrel/Core/src/Internal/Http/HttpProtocol.cs index 72bedf1be684..adb1f2fc1862 100644 --- a/src/Servers/Kestrel/Core/src/Internal/Http/HttpProtocol.cs +++ b/src/Servers/Kestrel/Core/src/Internal/Http/HttpProtocol.cs @@ -62,12 +62,12 @@ public abstract partial class HttpProtocol : IHttpResponseControl private long _responseBytesWritten; - private readonly HttpConnectionContext _context; + private HttpConnectionContext _context; protected string _methodText = null; private string _scheme = null; - public HttpProtocol(HttpConnectionContext context) + protected void Initialize(HttpConnectionContext context) { _context = context; @@ -90,7 +90,7 @@ public HttpProtocol(HttpConnectionContext context) protected IKestrelTrace Log => ServiceContext.Log; private DateHeaderValueManager DateHeaderValueManager => ServiceContext.DateHeaderValueManager; // Hold direct reference to ServerOptions since this is used very often in the request processing path - protected KestrelServerOptions ServerOptions { get; } + protected KestrelServerOptions ServerOptions { get; set; } protected string ConnectionId => _context.ConnectionId; public string ConnectionIdFeature { get; set; } diff --git a/src/Servers/Kestrel/Core/src/Internal/Http2/Http2Connection.cs b/src/Servers/Kestrel/Core/src/Internal/Http2/Http2Connection.cs index 9a043b703607..95036a99740b 100644 --- a/src/Servers/Kestrel/Core/src/Internal/Http2/Http2Connection.cs +++ b/src/Servers/Kestrel/Core/src/Internal/Http2/Http2Connection.cs @@ -61,6 +61,10 @@ private enum PseudoHeaderFields private static readonly byte[] _trailersBytes = Encoding.ASCII.GetBytes("trailers"); private static readonly byte[] _connectBytes = Encoding.ASCII.GetBytes("CONNECT"); + // Since the number of streams per connection is user configurable, we need a max just in case that number is too big + // 200 seems reasonable since the default is 100 + private static readonly int _maxPooledStreams = 200; + private readonly HttpConnectionContext _context; private readonly Http2FrameWriter _frameWriter; private readonly HPackDecoder _hpackDecoder; @@ -71,6 +75,8 @@ private enum PseudoHeaderFields private readonly Http2PeerSettings _clientSettings = new Http2PeerSettings(); private readonly Http2Frame _incomingFrame = new Http2Frame(); + private readonly Http2Stream[] _streamPool; + private int _pooledStreamCount; private Http2Stream _currentHeadersStream; private RequestHeaderParsingState _requestHeaderParsingState; @@ -114,6 +120,9 @@ public Http2Connection(HttpConnectionContext context) _serverSettings.HeaderTableSize = (uint)http2Limits.HeaderTableSize; _serverSettings.MaxHeaderListSize = (uint)httpLimits.MaxRequestHeadersTotalSize; _serverSettings.InitialWindowSize = (uint)http2Limits.InitialStreamWindowSize; + + // Pool the set of streams on this connection + _streamPool = new Http2Stream[Math.Min(_maxPooledStreams, http2Limits.MaxStreamsPerConnection)]; } public string ConnectionId => _context.ConnectionId; @@ -194,6 +203,58 @@ public void StopProcessingNextRequest(bool sendGracefulGoAway = false) } } + private Http2Stream CreateStream(IHttpApplication application) + { + Http2Stream stream = null; + + lock (_streamPool) + { + if (_pooledStreamCount > 0) + { + _pooledStreamCount--; + stream = (Http2Stream)_streamPool[_pooledStreamCount]; + } + } + + if (stream == null) + { + stream = new Http2Stream(); + } + + stream.HttpApplication = application; + stream.Initialize(new Http2StreamContext + { + ConnectionId = ConnectionId, + StreamId = _incomingFrame.StreamId, + ServiceContext = _context.ServiceContext, + ConnectionFeatures = _context.ConnectionFeatures, + MemoryPool = _context.MemoryPool, + LocalEndPoint = _context.LocalEndPoint, + RemoteEndPoint = _context.RemoteEndPoint, + StreamLifetimeHandler = this, + ClientPeerSettings = _clientSettings, + ServerPeerSettings = _serverSettings, + FrameWriter = _frameWriter, + ConnectionInputFlowControl = _inputFlowControl, + ConnectionOutputFlowControl = _outputFlowControl, + TimeoutControl = TimeoutControl, + }); + + return stream; + } + + private void ReturnStream(Http2Stream stream) + { + lock (_streamPool) + { + if (_pooledStreamCount < _streamPool.Length) + { + _streamPool[_pooledStreamCount] = stream; + _pooledStreamCount++; + } + } + } + public async Task ProcessRequestsAsync(IHttpApplication application) { Exception error = null; @@ -453,7 +514,7 @@ private Task ProcessFrameAsync(IHttpApplication application, case Http2FrameType.WINDOW_UPDATE: return ProcessWindowUpdateFrameAsync(); case Http2FrameType.CONTINUATION: - return ProcessContinuationFrameAsync(application, payload); + return ProcessContinuationFrameAsync(payload); default: return ProcessUnknownFrameAsync(); } @@ -605,29 +666,13 @@ private Task ProcessHeadersFrameAsync(IHttpApplication appli } // Start a new stream - _currentHeadersStream = new Http2Stream(new Http2StreamContext - { - ConnectionId = ConnectionId, - StreamId = _incomingFrame.StreamId, - ServiceContext = _context.ServiceContext, - ConnectionFeatures = _context.ConnectionFeatures, - MemoryPool = _context.MemoryPool, - LocalEndPoint = _context.LocalEndPoint, - RemoteEndPoint = _context.RemoteEndPoint, - StreamLifetimeHandler = this, - ClientPeerSettings = _clientSettings, - ServerPeerSettings = _serverSettings, - FrameWriter = _frameWriter, - ConnectionInputFlowControl = _inputFlowControl, - ConnectionOutputFlowControl = _outputFlowControl, - TimeoutControl = TimeoutControl, - }); + _currentHeadersStream = CreateStream(application); _currentHeadersStream.Reset(); _headerFlags = _incomingFrame.HeadersFlags; var headersPayload = payload.Slice(0, _incomingFrame.HeadersPayloadLength); // Minus padding - return DecodeHeadersAsync(application, _incomingFrame.HeadersEndHeaders, headersPayload); + return DecodeHeadersAsync(_incomingFrame.HeadersEndHeaders, headersPayload); } } } @@ -878,7 +923,7 @@ private Task ProcessWindowUpdateFrameAsync() return Task.CompletedTask; } - private Task ProcessContinuationFrameAsync(IHttpApplication application, ReadOnlySequence payload) + private Task ProcessContinuationFrameAsync(ReadOnlySequence payload) { if (_currentHeadersStream == null) { @@ -905,7 +950,7 @@ private Task ProcessContinuationFrameAsync(IHttpApplication TimeoutControl.CancelTimeout(); } - return DecodeHeadersAsync(application, _incomingFrame.ContinuationEndHeaders, payload); + return DecodeHeadersAsync(_incomingFrame.ContinuationEndHeaders, payload); } } } @@ -921,7 +966,7 @@ private Task ProcessUnknownFrameAsync() } // This is always called with the _stateLock acquired. - private Task DecodeHeadersAsync(IHttpApplication application, bool endHeaders, ReadOnlySequence payload) + private Task DecodeHeadersAsync(bool endHeaders, ReadOnlySequence payload) { try { @@ -932,7 +977,7 @@ private Task DecodeHeadersAsync(IHttpApplication application { if (_state != Http2ConnectionState.Closed) { - StartStream(application); + StartStream(); } ResetRequestHeaderParsingState(); @@ -969,7 +1014,7 @@ private Task DecodeTrailersAsync(bool endHeaders, ReadOnlySequence payload return Task.CompletedTask; } - private void StartStream(IHttpApplication application) + private void StartStream() { if (!_isMethodConnect && (_parsedPseudoHeaderFields & _mandatoryRequestPseudoHeaderFields) != _mandatoryRequestPseudoHeaderFields) { @@ -995,13 +1040,9 @@ private void StartStream(IHttpApplication application) _activeStreamCount++; _streams[_incomingFrame.StreamId] = _currentHeadersStream; + // Must not allow app code to block the connection handling loop. - ThreadPool.UnsafeQueueUserWorkItem(state => - { - var (app, currentStream) = (Tuple, Http2Stream>)state; - _ = currentStream.ProcessRequestsAsync(app); - }, - new Tuple, Http2Stream>(application, _currentHeadersStream)); + ThreadPool.UnsafeQueueUserWorkItem(_currentHeadersStream, preferLocal: false); } private void ResetRequestHeaderParsingState() @@ -1059,7 +1100,9 @@ void IHttp2StreamLifetimeHandler.OnStreamCompleted(int streamId) } else { - _streams.TryRemove(streamId, out _); + _streams.TryRemove(streamId, out stream); + + ReturnStream(stream); } } diff --git a/src/Servers/Kestrel/Core/src/Internal/Http2/Http2Stream.cs b/src/Servers/Kestrel/Core/src/Internal/Http2/Http2Stream.cs index da9bbc9dfed8..e74e33d237df 100644 --- a/src/Servers/Kestrel/Core/src/Internal/Http2/Http2Stream.cs +++ b/src/Servers/Kestrel/Core/src/Internal/Http2/Http2Stream.cs @@ -6,6 +6,7 @@ using System.Diagnostics; using System.IO; using System.IO.Pipelines; +using System.Threading; using System.Threading.Tasks; using Microsoft.AspNetCore.Connections; using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http; @@ -17,21 +18,22 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2 { - public partial class Http2Stream : HttpProtocol + public abstract partial class Http2Stream : HttpProtocol, IThreadPoolWorkItem { - private readonly Http2StreamContext _context; - private readonly Http2OutputProducer _http2Output; - private readonly StreamInputFlowControl _inputFlowControl; - private readonly StreamOutputFlowControl _outputFlowControl; + private Http2StreamContext _context; + private Http2OutputProducer _http2Output; + private StreamInputFlowControl _inputFlowControl; + private StreamOutputFlowControl _outputFlowControl; internal long DrainExpirationTicks { get; set; } private StreamCompletionFlags _completionState; private readonly object _completionLock = new object(); - public Http2Stream(Http2StreamContext context) - : base(context) + public virtual void Initialize(Http2StreamContext context) { + base.Initialize(context); + _context = context; _inputFlowControl = new StreamInputFlowControl( @@ -502,6 +504,11 @@ private Pipe CreateRequestBodyPipe(uint windowSize) } } + /// + /// Used to kick off the request processing loop by derived classes. + /// + public abstract void Execute(); + [Flags] private enum StreamCompletionFlags { diff --git a/src/Servers/Kestrel/Core/src/Internal/Http2/Http2StreamOfT.cs b/src/Servers/Kestrel/Core/src/Internal/Http2/Http2StreamOfT.cs new file mode 100644 index 000000000000..16be550fa974 --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Internal/Http2/Http2StreamOfT.cs @@ -0,0 +1,18 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Microsoft.AspNetCore.Hosting.Server; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2 +{ + public class Http2Stream : Http2Stream + { + public IHttpApplication HttpApplication { get; set; } + + public override void Execute() + { + // REVIEW: Should we store this in a field for easy debugging? + _ = ProcessRequestsAsync(HttpApplication); + } + } +} diff --git a/src/Servers/Kestrel/Core/test/HttpProtocolFeatureCollectionTests.cs b/src/Servers/Kestrel/Core/test/HttpProtocolFeatureCollectionTests.cs index 001d3952b39d..ded2eec7af16 100644 --- a/src/Servers/Kestrel/Core/test/HttpProtocolFeatureCollectionTests.cs +++ b/src/Servers/Kestrel/Core/test/HttpProtocolFeatureCollectionTests.cs @@ -1,4 +1,4 @@ -// Copyright (c) .NET Foundation. All rights reserved. +// Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; @@ -40,7 +40,8 @@ public HttpProtocolFeatureCollectionTests() _http1Connection.Reset(); _collection = _http1Connection; - var http2Stream = new Http2Stream(context); + var http2Stream = new TestHttp2Stream(); + http2Stream.Initialize(context); http2Stream.Reset(); _http2Collection = http2Stream; } @@ -220,5 +221,12 @@ private int SetFeaturesToNonDefault() } private Http1Connection CreateHttp1Connection() => new TestHttp1Connection(_httpConnectionContext); + + private class TestHttp2Stream : Http2Stream + { + public override void Execute() + { + } + } } }