diff --git a/src/Servers/Kestrel/Core/src/Internal/Http/Http1ChunkedEncodingMessageBody.cs b/src/Servers/Kestrel/Core/src/Internal/Http/Http1ChunkedEncodingMessageBody.cs new file mode 100644 index 000000000000..963fd5d85c84 --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Internal/Http/Http1ChunkedEncodingMessageBody.cs @@ -0,0 +1,542 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Buffers; +using System.Diagnostics; +using System.IO; +using System.IO.Pipelines; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Connections; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure; +using Microsoft.AspNetCore.Server.Kestrel.Transport.Abstractions.Internal; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http +{ + /// + /// http://tools.ietf.org/html/rfc2616#section-3.6.1 + /// + public class Http1ChunkedEncodingMessageBody : Http1MessageBody + { + // byte consts don't have a data type annotation so we pre-cast it + private const byte ByteCR = (byte)'\r'; + // "7FFFFFFF\r\n" is the largest chunk size that could be returned as an int. + private const int MaxChunkPrefixBytes = 10; + + private long _inputLength; + + private Mode _mode = Mode.Prefix; + private volatile bool _canceled; + private Task _pumpTask; + private Pipe _requestBodyPipe; + private ReadResult _readResult; + + public Http1ChunkedEncodingMessageBody(bool keepAlive, Http1Connection context) + : base(context) + { + RequestKeepAlive = keepAlive; + + _requestBodyPipe = CreateRequestBodyPipe(context); + } + + public override void AdvanceTo(SequencePosition consumed) + { + AdvanceTo(consumed, consumed); + } + + public override void AdvanceTo(SequencePosition consumed, SequencePosition examined) + { + var dataLength = _readResult.Buffer.Slice(_readResult.Buffer.Start, consumed).Length; + _requestBodyPipe.Reader.AdvanceTo(consumed, examined); + OnDataRead(dataLength); + } + + public override bool TryRead(out ReadResult readResult) + { + TryStart(); + + var boolResult = _requestBodyPipe.Reader.TryRead(out _readResult); + + readResult = _readResult; + + if (_readResult.IsCompleted) + { + TryStop(); + } + + return boolResult; + } + + public override async ValueTask ReadAsync(CancellationToken cancellationToken = default) + { + TryStart(); + + try + { + var readAwaitable = _requestBodyPipe.Reader.ReadAsync(cancellationToken); + + _readResult = await StartTimingReadAsync(readAwaitable, cancellationToken); + } + catch (ConnectionAbortedException ex) + { + throw new TaskCanceledException("The request was aborted", ex); + } + + StopTimingRead(_readResult.Buffer.Length); + + if (_readResult.IsCompleted) + { + TryStop(); + } + + return _readResult; + } + + public override void Complete(Exception exception) + { + _requestBodyPipe.Reader.Complete(); + _context.ReportApplicationError(exception); + } + + public override void OnWriterCompleted(Action callback, object state) + { + _requestBodyPipe.Reader.OnWriterCompleted(callback, state); + } + + public override void CancelPendingRead() + { + _requestBodyPipe.Reader.CancelPendingRead(); + } + + private async Task PumpAsync() + { + Debug.Assert(!RequestUpgrade, "Upgraded connections should never use this code path!"); + + Exception error = null; + + try + { + var awaitable = _context.Input.ReadAsync(); + + if (!awaitable.IsCompleted) + { + TryProduceContinue(); + } + + while (true) + { + var result = await awaitable; + + if (_context.RequestTimedOut) + { + BadHttpRequestException.Throw(RequestRejectionReason.RequestBodyTimeout); + } + + var readableBuffer = result.Buffer; + var consumed = readableBuffer.Start; + var examined = readableBuffer.Start; + + try + { + if (_canceled) + { + break; + } + + if (!readableBuffer.IsEmpty) + { + bool done; + done = Read(readableBuffer, _requestBodyPipe.Writer, out consumed, out examined); + + await _requestBodyPipe.Writer.FlushAsync(); + + if (done) + { + break; + } + } + + // Read() will have already have greedily consumed the entire request body if able. + CheckCompletedReadResult(result); + } + finally + { + _context.Input.AdvanceTo(consumed, examined); + } + + awaitable = _context.Input.ReadAsync(); + } + } + catch (Exception ex) + { + error = ex; + } + finally + { + _requestBodyPipe.Writer.Complete(error); + } + } + + protected override Task OnStopAsync() + { + if (!_context.HasStartedConsumingRequestBody) + { + return Task.CompletedTask; + } + + // call complete here on the reader + _requestBodyPipe.Reader.Complete(); + + // PumpTask catches all Exceptions internally. + if (_pumpTask.IsCompleted) + { + // At this point both the request body pipe reader and writer should be completed. + _requestBodyPipe.Reset(); + return Task.CompletedTask; + } + + // Should I call complete here? + return StopAsyncAwaited(); + } + + private async Task StopAsyncAwaited() + { + _canceled = true; + _context.Input.CancelPendingRead(); + await _pumpTask; + + // At this point both the request body pipe reader and writer should be completed. + _requestBodyPipe.Reset(); + } + + protected void Copy(ReadOnlySequence readableBuffer, PipeWriter writableBuffer) + { + if (readableBuffer.IsSingleSegment) + { + writableBuffer.Write(readableBuffer.First.Span); + } + else + { + foreach (var memory in readableBuffer) + { + writableBuffer.Write(memory.Span); + } + } + } + + protected override void OnReadStarted() + { + _pumpTask = PumpAsync(); + } + + protected bool Read(ReadOnlySequence readableBuffer, PipeWriter writableBuffer, out SequencePosition consumed, out SequencePosition examined) + { + consumed = default; + examined = default; + + while (_mode < Mode.Trailer) + { + if (_mode == Mode.Prefix) + { + ParseChunkedPrefix(readableBuffer, out consumed, out examined); + + if (_mode == Mode.Prefix) + { + return false; + } + + readableBuffer = readableBuffer.Slice(consumed); + } + + if (_mode == Mode.Extension) + { + ParseExtension(readableBuffer, out consumed, out examined); + + if (_mode == Mode.Extension) + { + return false; + } + + readableBuffer = readableBuffer.Slice(consumed); + } + + if (_mode == Mode.Data) + { + ReadChunkedData(readableBuffer, writableBuffer, out consumed, out examined); + + if (_mode == Mode.Data) + { + return false; + } + + readableBuffer = readableBuffer.Slice(consumed); + } + + if (_mode == Mode.Suffix) + { + ParseChunkedSuffix(readableBuffer, out consumed, out examined); + + if (_mode == Mode.Suffix) + { + return false; + } + + readableBuffer = readableBuffer.Slice(consumed); + } + } + + // Chunks finished, parse trailers + if (_mode == Mode.Trailer) + { + ParseChunkedTrailer(readableBuffer, out consumed, out examined); + + if (_mode == Mode.Trailer) + { + return false; + } + + readableBuffer = readableBuffer.Slice(consumed); + } + + // _consumedBytes aren't tracked for trailer headers, since headers have separate limits. + if (_mode == Mode.TrailerHeaders) + { + if (_context.TakeMessageHeaders(readableBuffer, out consumed, out examined)) + { + _mode = Mode.Complete; + } + } + + return _mode == Mode.Complete; + } + + private void ParseChunkedPrefix(ReadOnlySequence buffer, out SequencePosition consumed, out SequencePosition examined) + { + consumed = buffer.Start; + examined = buffer.Start; + var reader = new SequenceReader(buffer); + + if (!reader.TryRead(out var ch1) || !reader.TryRead(out var ch2)) + { + examined = reader.Position; + return; + } + + var chunkSize = CalculateChunkSize(ch1, 0); + ch1 = ch2; + + while (reader.Consumed < MaxChunkPrefixBytes) + { + if (ch1 == ';') + { + consumed = reader.Position; + examined = reader.Position; + + AddAndCheckConsumedBytes(reader.Consumed); + _inputLength = chunkSize; + _mode = Mode.Extension; + return; + } + + if (!reader.TryRead(out ch2)) + { + examined = reader.Position; + return; + } + + if (ch1 == '\r' && ch2 == '\n') + { + consumed = reader.Position; + examined = reader.Position; + + AddAndCheckConsumedBytes(reader.Consumed); + _inputLength = chunkSize; + _mode = chunkSize > 0 ? Mode.Data : Mode.Trailer; + return; + } + + chunkSize = CalculateChunkSize(ch1, chunkSize); + ch1 = ch2; + } + + // At this point, 10 bytes have been consumed which is enough to parse the max value "7FFFFFFF\r\n". + BadHttpRequestException.Throw(RequestRejectionReason.BadChunkSizeData); + } + + private void ParseExtension(ReadOnlySequence buffer, out SequencePosition consumed, out SequencePosition examined) + { + // Chunk-extensions not currently parsed + // Just drain the data + consumed = buffer.Start; + examined = buffer.Start; + + do + { + SequencePosition? extensionCursorPosition = buffer.PositionOf(ByteCR); + if (extensionCursorPosition == null) + { + // End marker not found yet + consumed = buffer.End; + examined = buffer.End; + AddAndCheckConsumedBytes(buffer.Length); + return; + }; + + var extensionCursor = extensionCursorPosition.Value; + var charsToByteCRExclusive = buffer.Slice(0, extensionCursor).Length; + + var suffixBuffer = buffer.Slice(extensionCursor); + if (suffixBuffer.Length < 2) + { + consumed = extensionCursor; + examined = buffer.End; + AddAndCheckConsumedBytes(charsToByteCRExclusive); + return; + } + + suffixBuffer = suffixBuffer.Slice(0, 2); + var suffixSpan = suffixBuffer.ToSpan(); + + if (suffixSpan[1] == '\n') + { + // We consumed the \r\n at the end of the extension, so switch modes. + _mode = _inputLength > 0 ? Mode.Data : Mode.Trailer; + + consumed = suffixBuffer.End; + examined = suffixBuffer.End; + AddAndCheckConsumedBytes(charsToByteCRExclusive + 2); + } + else + { + // Don't consume suffixSpan[1] in case it is also a \r. + buffer = buffer.Slice(charsToByteCRExclusive + 1); + consumed = extensionCursor; + AddAndCheckConsumedBytes(charsToByteCRExclusive + 1); + } + } while (_mode == Mode.Extension); + } + + private void ReadChunkedData(ReadOnlySequence buffer, PipeWriter writableBuffer, out SequencePosition consumed, out SequencePosition examined) + { + var actual = Math.Min(buffer.Length, _inputLength); + consumed = buffer.GetPosition(actual); + examined = consumed; + + Copy(buffer.Slice(0, actual), writableBuffer); + + _inputLength -= actual; + AddAndCheckConsumedBytes(actual); + + if (_inputLength == 0) + { + _mode = Mode.Suffix; + } + } + + private void ParseChunkedSuffix(ReadOnlySequence buffer, out SequencePosition consumed, out SequencePosition examined) + { + consumed = buffer.Start; + examined = buffer.Start; + + if (buffer.Length < 2) + { + examined = buffer.End; + return; + } + + var suffixBuffer = buffer.Slice(0, 2); + var suffixSpan = suffixBuffer.ToSpan(); + if (suffixSpan[0] == '\r' && suffixSpan[1] == '\n') + { + consumed = suffixBuffer.End; + examined = suffixBuffer.End; + AddAndCheckConsumedBytes(2); + _mode = Mode.Prefix; + } + else + { + BadHttpRequestException.Throw(RequestRejectionReason.BadChunkSuffix); + } + } + + private void ParseChunkedTrailer(ReadOnlySequence buffer, out SequencePosition consumed, out SequencePosition examined) + { + consumed = buffer.Start; + examined = buffer.Start; + + if (buffer.Length < 2) + { + examined = buffer.End; + return; + } + + var trailerBuffer = buffer.Slice(0, 2); + var trailerSpan = trailerBuffer.ToSpan(); + + if (trailerSpan[0] == '\r' && trailerSpan[1] == '\n') + { + consumed = trailerBuffer.End; + examined = trailerBuffer.End; + AddAndCheckConsumedBytes(2); + _mode = Mode.Complete; + } + else + { + _mode = Mode.TrailerHeaders; + } + } + + private int CalculateChunkSize(int extraHexDigit, int currentParsedSize) + { + try + { + checked + { + if (extraHexDigit >= '0' && extraHexDigit <= '9') + { + return currentParsedSize * 0x10 + (extraHexDigit - '0'); + } + else if (extraHexDigit >= 'A' && extraHexDigit <= 'F') + { + return currentParsedSize * 0x10 + (extraHexDigit - ('A' - 10)); + } + else if (extraHexDigit >= 'a' && extraHexDigit <= 'f') + { + return currentParsedSize * 0x10 + (extraHexDigit - ('a' - 10)); + } + } + } + catch (OverflowException ex) + { + throw new IOException(CoreStrings.BadRequest_BadChunkSizeData, ex); + } + + BadHttpRequestException.Throw(RequestRejectionReason.BadChunkSizeData); + return -1; // can't happen, but compiler complains + } + + private enum Mode + { + Prefix, + Extension, + Data, + Suffix, + Trailer, + TrailerHeaders, + Complete + }; + + private Pipe CreateRequestBodyPipe(Http1Connection context) + => new Pipe(new PipeOptions + ( + pool: context.MemoryPool, + readerScheduler: context.ServiceContext.Scheduler, + writerScheduler: PipeScheduler.Inline, + pauseWriterThreshold: 1, + resumeWriterThreshold: 1, + useSynchronizationContext: false, + minimumSegmentSize: KestrelMemoryPool.MinimumSegmentSize + )); + } +} diff --git a/src/Servers/Kestrel/Core/src/Internal/Http/Http1Connection.cs b/src/Servers/Kestrel/Core/src/Internal/Http/Http1Connection.cs index 2baffd56fd17..f418b475b7aa 100644 --- a/src/Servers/Kestrel/Core/src/Internal/Http/Http1Connection.cs +++ b/src/Servers/Kestrel/Core/src/Internal/Http/Http1Connection.cs @@ -44,8 +44,6 @@ public Http1Connection(HttpConnectionContext context) _keepAliveTicks = ServerOptions.Limits.KeepAliveTimeout.Ticks; _requestHeadersTimeoutTicks = ServerOptions.Limits.RequestHeadersTimeout.Ticks; - RequestBodyPipe = CreateRequestBodyPipe(); - _http1Output = new Http1OutputProducer( _context.Transport.Output, _context.ConnectionId, @@ -57,6 +55,7 @@ public Http1Connection(HttpConnectionContext context) Input = _context.Transport.Input; Output = _http1Output; + MemoryPool = _context.MemoryPool; } public PipeReader Input { get; } @@ -67,6 +66,8 @@ public Http1Connection(HttpConnectionContext context) public MinDataRate MinResponseDataRate { get; set; } + public MemoryPool MemoryPool { get; } + protected override void OnRequestProcessingEnded() { Input.Complete(); @@ -531,17 +532,5 @@ protected override bool TryParseRequest(ReadResult result, out bool endConnectio } void IRequestProcessor.Tick(DateTimeOffset now) { } - - private Pipe CreateRequestBodyPipe() - => new Pipe(new PipeOptions - ( - pool: _context.MemoryPool, - readerScheduler: ServiceContext.Scheduler, - writerScheduler: PipeScheduler.Inline, - pauseWriterThreshold: 1, - resumeWriterThreshold: 1, - useSynchronizationContext: false, - minimumSegmentSize: KestrelMemoryPool.MinimumSegmentSize - )); } } diff --git a/src/Servers/Kestrel/Core/src/Internal/Http/Http1ContentLengthMessageBody.cs b/src/Servers/Kestrel/Core/src/Internal/Http/Http1ContentLengthMessageBody.cs new file mode 100644 index 000000000000..289ecac406f2 --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Internal/Http/Http1ContentLengthMessageBody.cs @@ -0,0 +1,213 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.IO.Pipelines; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Connections; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http +{ + public class Http1ContentLengthMessageBody : Http1MessageBody + { + private readonly long _contentLength; + private long _inputLength; + private ReadResult _readResult; + private bool _completed; + private int _userCanceled; + + public Http1ContentLengthMessageBody(bool keepAlive, long contentLength, Http1Connection context) + : base(context) + { + RequestKeepAlive = keepAlive; + _contentLength = contentLength; + _inputLength = _contentLength; + } + + public override async ValueTask ReadAsync(CancellationToken cancellationToken = default) + { + ThrowIfCompleted(); + + if (_inputLength == 0) + { + _readResult = new ReadResult(default, isCanceled: false, isCompleted: true); + return _readResult; + } + + TryStart(); + + // The while(true) loop is required because the Http1 connection calls CancelPendingRead to unblock + // the call to StartTimingReadAsync to check if the request timed out. + // However, if the user called CancelPendingRead, we want that to return a canceled ReadResult + // We internally track an int for that. + while (true) + { + // The issue is that TryRead can get a canceled read result + // which is unknown to StartTimingReadAsync. + if (_context.RequestTimedOut) + { + BadHttpRequestException.Throw(RequestRejectionReason.RequestBodyTimeout); + } + + try + { + var readAwaitable = _context.Input.ReadAsync(cancellationToken); + _readResult = await StartTimingReadAsync(readAwaitable, cancellationToken); + } + catch (ConnectionAbortedException ex) + { + throw new TaskCanceledException("The request was aborted", ex); + } + + if (_context.RequestTimedOut) + { + BadHttpRequestException.Throw(RequestRejectionReason.RequestBodyTimeout); + } + + // Make sure to handle when this is canceled here. + if (_readResult.IsCanceled) + { + if (Interlocked.Exchange(ref _userCanceled, 0) == 1) + { + // Ignore the readResult if it wasn't by the user. + break; + } + else + { + // Reset the timing read here for the next call to read. + StopTimingRead(0); + continue; + } + } + + var readableBuffer = _readResult.Buffer; + var readableBufferLength = readableBuffer.Length; + StopTimingRead(readableBufferLength); + + CheckCompletedReadResult(_readResult); + + if (readableBufferLength > 0) + { + CreateReadResultFromConnectionReadResult(); + + break; + } + } + + return _readResult; + } + + public override bool TryRead(out ReadResult readResult) + { + ThrowIfCompleted(); + + if (_inputLength == 0) + { + readResult = new ReadResult(default, isCanceled: false, isCompleted: true); + return true; + } + + TryStart(); + + if (!_context.Input.TryRead(out _readResult)) + { + readResult = default; + return false; + } + + if (_readResult.IsCanceled) + { + if (Interlocked.Exchange(ref _userCanceled, 0) == 0) + { + // Cancellation wasn't by the user, return default ReadResult + readResult = default; + return false; + } + } + + CreateReadResultFromConnectionReadResult(); + + readResult = _readResult; + + return true; + } + + private void ThrowIfCompleted() + { + if (_completed) + { + throw new InvalidOperationException("Reading is not allowed after the reader was completed."); + } + } + + private void CreateReadResultFromConnectionReadResult() + { + if (_readResult.Buffer.Length > _inputLength) + { + _readResult = new ReadResult(_readResult.Buffer.Slice(0, _inputLength), _readResult.IsCanceled, isCompleted: true); + } + else if (_readResult.Buffer.Length == _inputLength) + { + _readResult = new ReadResult(_readResult.Buffer, _readResult.IsCanceled, isCompleted: true); + } + + if (_readResult.IsCompleted) + { + TryStop(); + } + } + + public override void AdvanceTo(SequencePosition consumed) + { + AdvanceTo(consumed, consumed); + } + + public override void AdvanceTo(SequencePosition consumed, SequencePosition examined) + { + if (_inputLength == 0) + { + return; + } + + var dataLength = _readResult.Buffer.Slice(_readResult.Buffer.Start, consumed).Length; + + _inputLength -= dataLength; + + _context.Input.AdvanceTo(consumed, examined); + + OnDataRead(dataLength); + } + + protected override void OnReadStarting() + { + if (_contentLength > _context.MaxRequestBodySize) + { + BadHttpRequestException.Throw(RequestRejectionReason.RequestBodyTooLarge); + } + } + + public override void Complete(Exception exception) + { + _context.ReportApplicationError(exception); + _completed = true; + } + + public override void OnWriterCompleted(Action callback, object state) + { + // TODO make this work with ContentLength. + } + + public override void CancelPendingRead() + { + Interlocked.Exchange(ref _userCanceled, 1); + _context.Input.CancelPendingRead(); + } + + protected override Task OnStopAsync() + { + Complete(null); + return Task.CompletedTask; + } + } +} diff --git a/src/Servers/Kestrel/Core/src/Internal/Http/Http1MessageBody.cs b/src/Servers/Kestrel/Core/src/Internal/Http/Http1MessageBody.cs index fb34300a4bd4..71c346897081 100644 --- a/src/Servers/Kestrel/Core/src/Internal/Http/Http1MessageBody.cs +++ b/src/Servers/Kestrel/Core/src/Internal/Http/Http1MessageBody.cs @@ -2,24 +2,16 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; -using System.Buffers; -using System.Diagnostics; -using System.IO; using System.IO.Pipelines; -using System.Threading; using System.Threading.Tasks; using Microsoft.AspNetCore.Connections; -using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure; namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http { public abstract class Http1MessageBody : MessageBody { - private readonly Http1Connection _context; - - private volatile bool _canceled; - private Task _pumpTask; + protected readonly Http1Connection _context; protected Http1MessageBody(Http1Connection context) : base(context, context.MinRequestBodyDataRate) @@ -27,120 +19,28 @@ protected Http1MessageBody(Http1Connection context) _context = context; } - private async Task PumpAsync() - { - Debug.Assert(!RequestUpgrade, "Upgraded connections should never use this code path!"); - - Exception error = null; - - try - { - var awaitable = _context.Input.ReadAsync(); - - if (!awaitable.IsCompleted) - { - TryProduceContinue(); - } - - while (true) - { - var result = await awaitable; - - if (_context.RequestTimedOut) - { - BadHttpRequestException.Throw(RequestRejectionReason.RequestBodyTimeout); - } - - var readableBuffer = result.Buffer; - var consumed = readableBuffer.Start; - var examined = readableBuffer.Start; - - try - { - if (_canceled) - { - break; - } - - if (!readableBuffer.IsEmpty) - { - bool done; - done = Read(readableBuffer, _context.RequestBodyPipe.Writer, out consumed, out examined); - - await _context.RequestBodyPipe.Writer.FlushAsync(); - - if (done) - { - break; - } - } - - // Read() will have already have greedily consumed the entire request body if able. - if (result.IsCompleted) - { - // OnInputOrOutputCompleted() is an idempotent method that closes the connection. Sometimes - // input completion is observed here before the Input.OnWriterCompleted() callback is fired, - // so we call OnInputOrOutputCompleted() now to prevent a race in our tests where a 400 - // response is written after observing the unexpected end of request content instead of just - // closing the connection without a response as expected. - _context.OnInputOrOutputCompleted(); - - BadHttpRequestException.Throw(RequestRejectionReason.UnexpectedEndOfRequestContent); - } - } - finally - { - _context.Input.AdvanceTo(consumed, examined); - } - - awaitable = _context.Input.ReadAsync(); - } - } - catch (Exception ex) - { - error = ex; - } - finally - { - _context.RequestBodyPipe.Writer.Complete(error); - } - } - - protected override Task OnStopAsync() + protected void CheckCompletedReadResult(ReadResult result) { - if (!_context.HasStartedConsumingRequestBody) + if (result.IsCompleted) { - return Task.CompletedTask; - } + // OnInputOrOutputCompleted() is an idempotent method that closes the connection. Sometimes + // input completion is observed here before the Input.OnWriterCompleted() callback is fired, + // so we call OnInputOrOutputCompleted() now to prevent a race in our tests where a 400 + // response is written after observing the unexpected end of request content instead of just + // closing the connection without a response as expected. + _context.OnInputOrOutputCompleted(); - // PumpTask catches all Exceptions internally. - if (_pumpTask.IsCompleted) - { - // At this point both the request body pipe reader and writer should be completed. - _context.RequestBodyPipe.Reset(); - return Task.CompletedTask; + BadHttpRequestException.Throw(RequestRejectionReason.UnexpectedEndOfRequestContent); } - - return StopAsyncAwaited(); - } - - private async Task StopAsyncAwaited() - { - _canceled = true; - _context.Input.CancelPendingRead(); - await _pumpTask; - - // At this point both the request body pipe reader and writer should be completed. - _context.RequestBodyPipe.Reset(); } protected override Task OnConsumeAsync() { try { - if (_context.RequestBodyPipe.Reader.TryRead(out var readResult)) + if (TryRead(out var readResult)) { - _context.RequestBodyPipe.Reader.AdvanceTo(readResult.Buffer.End); + AdvanceTo(readResult.Buffer.End); if (readResult.IsCompleted) { @@ -148,11 +48,6 @@ protected override Task OnConsumeAsync() } } } - catch (OperationCanceledException) - { - // TryRead can throw OperationCanceledException https://github.com/dotnet/corefx/issues/32029 - // because of buggy logic, this works around that for now - } catch (BadHttpRequestException ex) { // At this point, the response has already been written, so this won't result in a 4XX response; @@ -160,11 +55,20 @@ protected override Task OnConsumeAsync() _context.SetBadRequestState(ex); return Task.CompletedTask; } + catch (InvalidOperationException ex) + { + var connectionAbortedException = new ConnectionAbortedException(CoreStrings.ConnectionAbortedByApplication, ex); + _context.ReportApplicationError(connectionAbortedException); + + // Have to abort the connection because we can't finish draining the request + _context.StopProcessingNextRequest(); + return Task.CompletedTask; + } return OnConsumeAsyncAwaited(); } - private async Task OnConsumeAsyncAwaited() + protected async Task OnConsumeAsyncAwaited() { Log.RequestBodyNotEntirelyRead(_context.ConnectionIdFeature, _context.TraceIdentifier); @@ -175,49 +79,32 @@ private async Task OnConsumeAsyncAwaited() ReadResult result; do { - result = await _context.RequestBodyPipe.Reader.ReadAsync(); - _context.RequestBodyPipe.Reader.AdvanceTo(result.Buffer.End); + result = await ReadAsync(); + AdvanceTo(result.Buffer.End); } while (!result.IsCompleted); } catch (BadHttpRequestException ex) { _context.SetBadRequestState(ex); } - catch (ConnectionAbortedException) + catch (OperationCanceledException ex) when (ex is ConnectionAbortedException || ex is TaskCanceledException) { Log.RequestBodyDrainTimedOut(_context.ConnectionIdFeature, _context.TraceIdentifier); } - finally + catch (InvalidOperationException ex) { - _context.TimeoutControl.CancelTimeout(); - } - } + var connectionAbortedException = new ConnectionAbortedException(CoreStrings.ConnectionAbortedByApplication, ex); + _context.ReportApplicationError(connectionAbortedException); - protected void Copy(ReadOnlySequence readableBuffer, PipeWriter writableBuffer) - { - if (readableBuffer.IsSingleSegment) - { - writableBuffer.Write(readableBuffer.First.Span); + // Have to abort the connection because we can't finish draining the request + _context.StopProcessingNextRequest(); } - else + finally { - foreach (var memory in readableBuffer) - { - writableBuffer.Write(memory.Span); - } + _context.TimeoutControl.CancelTimeout(); } } - protected override void OnReadStarted() - { - _pumpTask = PumpAsync(); - } - - protected virtual bool Read(ReadOnlySequence readableBuffer, PipeWriter writableBuffer, out SequencePosition consumed, out SequencePosition examined) - { - throw new NotImplementedException(); - } - public static MessageBody For( HttpVersion httpVersion, HttpRequestHeaders headers, @@ -242,7 +129,7 @@ public static MessageBody For( BadHttpRequestException.Throw(RequestRejectionReason.UpgradeRequestCannotHavePayload); } - return new ForUpgrade(context); + return new Http1UpgradeMessageBody(context); } if (headers.HasTransferEncoding) @@ -261,7 +148,9 @@ public static MessageBody For( BadHttpRequestException.Throw(RequestRejectionReason.FinalTransferCodingNotChunked, in transferEncoding); } - return new ForChunkedEncoding(keepAlive, context); + // TODO may push more into the wrapper rather than just calling into the message body + // NBD for now. + return new Http1ChunkedEncodingMessageBody(keepAlive, context); } if (headers.ContentLength.HasValue) @@ -273,7 +162,7 @@ public static MessageBody For( return keepAlive ? MessageBody.ZeroContentLengthKeepAlive : MessageBody.ZeroContentLengthClose; } - return new ForContentLength(keepAlive, contentLength, context); + return new Http1ContentLengthMessageBody(keepAlive, contentLength, context); } // If we got here, request contains no Content-Length or Transfer-Encoding header. @@ -286,459 +175,5 @@ public static MessageBody For( return keepAlive ? MessageBody.ZeroContentLengthKeepAlive : MessageBody.ZeroContentLengthClose; } - - /// - /// The upgrade stream uses the raw connection stream instead of going through the RequestBodyPipe. This - /// removes the redundant copy from the transport pipe to the body pipe. - /// - private class ForUpgrade : Http1MessageBody - { - public ForUpgrade(Http1Connection context) - : base(context) - { - RequestUpgrade = true; - } - - // This returns IsEmpty so we can avoid draining the body (since it's basically an endless stream) - public override bool IsEmpty => true; - - public override async Task CopyToAsync(Stream destination, CancellationToken cancellationToken = default) - { - while (true) - { - var result = await _context.Input.ReadAsync(cancellationToken); - var readableBuffer = result.Buffer; - var readableBufferLength = readableBuffer.Length; - - try - { - if (!readableBuffer.IsEmpty) - { - foreach (var memory in readableBuffer) - { - // REVIEW: This *could* be slower if 2 things are true - // - The WriteAsync(ReadOnlyMemory) isn't overridden on the destination - // - We change the Kestrel Memory Pool to not use pinned arrays but instead use native memory - await destination.WriteAsync(memory, cancellationToken); - } - } - - if (result.IsCompleted) - { - return; - } - } - finally - { - _context.Input.AdvanceTo(readableBuffer.End); - } - } - } - - public override async ValueTask ReadAsync(Memory buffer, CancellationToken cancellationToken = default) - { - while (true) - { - var result = await _context.Input.ReadAsync(cancellationToken); - var readableBuffer = result.Buffer; - var readableBufferLength = readableBuffer.Length; - - var consumed = readableBuffer.End; - var actual = 0; - - try - { - if (readableBufferLength != 0) - { - // buffer.Length is int - actual = (int)Math.Min(readableBufferLength, buffer.Length); - - var slice = actual == readableBufferLength ? readableBuffer : readableBuffer.Slice(0, actual); - consumed = slice.End; - slice.CopyTo(buffer.Span); - - return actual; - } - - if (result.IsCompleted) - { - return 0; - } - } - finally - { - _context.Input.AdvanceTo(consumed); - } - } - } - - public override Task ConsumeAsync() - { - return Task.CompletedTask; - } - - public override Task StopAsync() - { - return Task.CompletedTask; - } - } - - private class ForContentLength : Http1MessageBody - { - private readonly long _contentLength; - private long _inputLength; - - public ForContentLength(bool keepAlive, long contentLength, Http1Connection context) - : base(context) - { - RequestKeepAlive = keepAlive; - _contentLength = contentLength; - _inputLength = _contentLength; - } - - protected override bool Read(ReadOnlySequence readableBuffer, PipeWriter writableBuffer, out SequencePosition consumed, out SequencePosition examined) - { - if (_inputLength == 0) - { - throw new InvalidOperationException("Attempted to read from completed Content-Length request body."); - } - - var actual = (int)Math.Min(readableBuffer.Length, _inputLength); - _inputLength -= actual; - - consumed = readableBuffer.GetPosition(actual); - examined = consumed; - - Copy(readableBuffer.Slice(0, actual), writableBuffer); - - return _inputLength == 0; - } - - protected override void OnReadStarting() - { - if (_contentLength > _context.MaxRequestBodySize) - { - BadHttpRequestException.Throw(RequestRejectionReason.RequestBodyTooLarge); - } - } - } - - /// - /// http://tools.ietf.org/html/rfc2616#section-3.6.1 - /// - private class ForChunkedEncoding : Http1MessageBody - { - // byte consts don't have a data type annotation so we pre-cast it - private const byte ByteCR = (byte)'\r'; - // "7FFFFFFF\r\n" is the largest chunk size that could be returned as an int. - private const int MaxChunkPrefixBytes = 10; - - private long _inputLength; - - private Mode _mode = Mode.Prefix; - - public ForChunkedEncoding(bool keepAlive, Http1Connection context) - : base(context) - { - RequestKeepAlive = keepAlive; - } - - protected override bool Read(ReadOnlySequence readableBuffer, PipeWriter writableBuffer, out SequencePosition consumed, out SequencePosition examined) - { - consumed = default(SequencePosition); - examined = default(SequencePosition); - - while (_mode < Mode.Trailer) - { - if (_mode == Mode.Prefix) - { - ParseChunkedPrefix(readableBuffer, out consumed, out examined); - - if (_mode == Mode.Prefix) - { - return false; - } - - readableBuffer = readableBuffer.Slice(consumed); - } - - if (_mode == Mode.Extension) - { - ParseExtension(readableBuffer, out consumed, out examined); - - if (_mode == Mode.Extension) - { - return false; - } - - readableBuffer = readableBuffer.Slice(consumed); - } - - if (_mode == Mode.Data) - { - ReadChunkedData(readableBuffer, writableBuffer, out consumed, out examined); - - if (_mode == Mode.Data) - { - return false; - } - - readableBuffer = readableBuffer.Slice(consumed); - } - - if (_mode == Mode.Suffix) - { - ParseChunkedSuffix(readableBuffer, out consumed, out examined); - - if (_mode == Mode.Suffix) - { - return false; - } - - readableBuffer = readableBuffer.Slice(consumed); - } - } - - // Chunks finished, parse trailers - if (_mode == Mode.Trailer) - { - ParseChunkedTrailer(readableBuffer, out consumed, out examined); - - if (_mode == Mode.Trailer) - { - return false; - } - - readableBuffer = readableBuffer.Slice(consumed); - } - - // _consumedBytes aren't tracked for trailer headers, since headers have separate limits. - if (_mode == Mode.TrailerHeaders) - { - if (_context.TakeMessageHeaders(readableBuffer, out consumed, out examined)) - { - _mode = Mode.Complete; - } - } - - return _mode == Mode.Complete; - } - - private void ParseChunkedPrefix(ReadOnlySequence buffer, out SequencePosition consumed, out SequencePosition examined) - { - consumed = buffer.Start; - examined = buffer.Start; - var reader = new SequenceReader(buffer); - - if (!reader.TryRead(out var ch1) || !reader.TryRead(out var ch2)) - { - examined = reader.Position; - return; - } - - var chunkSize = CalculateChunkSize(ch1, 0); - ch1 = ch2; - - while (reader.Consumed < MaxChunkPrefixBytes) - { - if (ch1 == ';') - { - consumed = reader.Position; - examined = reader.Position; - - AddAndCheckConsumedBytes(reader.Consumed); - _inputLength = chunkSize; - _mode = Mode.Extension; - return; - } - - if (!reader.TryRead(out ch2)) - { - examined = reader.Position; - return; - } - - if (ch1 == '\r' && ch2 == '\n') - { - consumed = reader.Position; - examined = reader.Position; - - AddAndCheckConsumedBytes(reader.Consumed); - _inputLength = chunkSize; - _mode = chunkSize > 0 ? Mode.Data : Mode.Trailer; - return; - } - - chunkSize = CalculateChunkSize(ch1, chunkSize); - ch1 = ch2; - } - - // At this point, 10 bytes have been consumed which is enough to parse the max value "7FFFFFFF\r\n". - BadHttpRequestException.Throw(RequestRejectionReason.BadChunkSizeData); - } - - private void ParseExtension(ReadOnlySequence buffer, out SequencePosition consumed, out SequencePosition examined) - { - // Chunk-extensions not currently parsed - // Just drain the data - consumed = buffer.Start; - examined = buffer.Start; - - do - { - SequencePosition? extensionCursorPosition = buffer.PositionOf(ByteCR); - if (extensionCursorPosition == null) - { - // End marker not found yet - consumed = buffer.End; - examined = buffer.End; - AddAndCheckConsumedBytes(buffer.Length); - return; - }; - - var extensionCursor = extensionCursorPosition.Value; - var charsToByteCRExclusive = buffer.Slice(0, extensionCursor).Length; - - var suffixBuffer = buffer.Slice(extensionCursor); - if (suffixBuffer.Length < 2) - { - consumed = extensionCursor; - examined = buffer.End; - AddAndCheckConsumedBytes(charsToByteCRExclusive); - return; - } - - suffixBuffer = suffixBuffer.Slice(0, 2); - var suffixSpan = suffixBuffer.ToSpan(); - - if (suffixSpan[1] == '\n') - { - // We consumed the \r\n at the end of the extension, so switch modes. - _mode = _inputLength > 0 ? Mode.Data : Mode.Trailer; - - consumed = suffixBuffer.End; - examined = suffixBuffer.End; - AddAndCheckConsumedBytes(charsToByteCRExclusive + 2); - } - else - { - // Don't consume suffixSpan[1] in case it is also a \r. - buffer = buffer.Slice(charsToByteCRExclusive + 1); - consumed = extensionCursor; - AddAndCheckConsumedBytes(charsToByteCRExclusive + 1); - } - } while (_mode == Mode.Extension); - } - - private void ReadChunkedData(ReadOnlySequence buffer, PipeWriter writableBuffer, out SequencePosition consumed, out SequencePosition examined) - { - var actual = Math.Min(buffer.Length, _inputLength); - consumed = buffer.GetPosition(actual); - examined = consumed; - - Copy(buffer.Slice(0, actual), writableBuffer); - - _inputLength -= actual; - AddAndCheckConsumedBytes(actual); - - if (_inputLength == 0) - { - _mode = Mode.Suffix; - } - } - - private void ParseChunkedSuffix(ReadOnlySequence buffer, out SequencePosition consumed, out SequencePosition examined) - { - consumed = buffer.Start; - examined = buffer.Start; - - if (buffer.Length < 2) - { - examined = buffer.End; - return; - } - - var suffixBuffer = buffer.Slice(0, 2); - var suffixSpan = suffixBuffer.ToSpan(); - if (suffixSpan[0] == '\r' && suffixSpan[1] == '\n') - { - consumed = suffixBuffer.End; - examined = suffixBuffer.End; - AddAndCheckConsumedBytes(2); - _mode = Mode.Prefix; - } - else - { - BadHttpRequestException.Throw(RequestRejectionReason.BadChunkSuffix); - } - } - - private void ParseChunkedTrailer(ReadOnlySequence buffer, out SequencePosition consumed, out SequencePosition examined) - { - consumed = buffer.Start; - examined = buffer.Start; - - if (buffer.Length < 2) - { - examined = buffer.End; - return; - } - - var trailerBuffer = buffer.Slice(0, 2); - var trailerSpan = trailerBuffer.ToSpan(); - - if (trailerSpan[0] == '\r' && trailerSpan[1] == '\n') - { - consumed = trailerBuffer.End; - examined = trailerBuffer.End; - AddAndCheckConsumedBytes(2); - _mode = Mode.Complete; - } - else - { - _mode = Mode.TrailerHeaders; - } - } - - private int CalculateChunkSize(int extraHexDigit, int currentParsedSize) - { - try - { - checked - { - if (extraHexDigit >= '0' && extraHexDigit <= '9') - { - return currentParsedSize * 0x10 + (extraHexDigit - '0'); - } - else if (extraHexDigit >= 'A' && extraHexDigit <= 'F') - { - return currentParsedSize * 0x10 + (extraHexDigit - ('A' - 10)); - } - else if (extraHexDigit >= 'a' && extraHexDigit <= 'f') - { - return currentParsedSize * 0x10 + (extraHexDigit - ('a' - 10)); - } - } - } - catch (OverflowException ex) - { - throw new IOException(CoreStrings.BadRequest_BadChunkSizeData, ex); - } - - BadHttpRequestException.Throw(RequestRejectionReason.BadChunkSizeData); - return -1; // can't happen, but compiler complains - } - - private enum Mode - { - Prefix, - Extension, - Data, - Suffix, - Trailer, - TrailerHeaders, - Complete - }; - } } } diff --git a/src/Servers/Kestrel/Core/src/Internal/Http/Http1UpgradeMessageBody.cs b/src/Servers/Kestrel/Core/src/Internal/Http/Http1UpgradeMessageBody.cs new file mode 100644 index 000000000000..1fcf18c37c43 --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Internal/Http/Http1UpgradeMessageBody.cs @@ -0,0 +1,82 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.IO.Pipelines; +using System.Threading; +using System.Threading.Tasks; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http +{ + /// + /// The upgrade stream uses the raw connection stream instead of going through the RequestBodyPipe. This + /// removes the redundant copy from the transport pipe to the body pipe. + /// + public class Http1UpgradeMessageBody : Http1MessageBody + { + public bool _completed; + public Http1UpgradeMessageBody(Http1Connection context) + : base(context) + { + RequestUpgrade = true; + } + + // This returns IsEmpty so we can avoid draining the body (since it's basically an endless stream) + public override bool IsEmpty => true; + + public override ValueTask ReadAsync(CancellationToken cancellationToken = default) + { + if (_completed) + { + throw new InvalidOperationException("Reading is not allowed after the reader was completed."); + } + return _context.Input.ReadAsync(cancellationToken); + } + + public override bool TryRead(out ReadResult result) + { + if (_completed) + { + throw new InvalidOperationException("Reading is not allowed after the reader was completed."); + } + return _context.Input.TryRead(out result); + } + + public override void AdvanceTo(SequencePosition consumed) + { + _context.Input.AdvanceTo(consumed); + } + + public override void AdvanceTo(SequencePosition consumed, SequencePosition examined) + { + _context.Input.AdvanceTo(consumed, examined); + } + + public override void Complete(Exception exception) + { + // Don't call Connection.Complete. + _context.ReportApplicationError(exception); + _completed = true; + } + + public override void CancelPendingRead() + { + _context.Input.CancelPendingRead(); + } + + public override void OnWriterCompleted(Action callback, object state) + { + _context.Input.OnWriterCompleted(callback, state); + } + + public override Task ConsumeAsync() + { + return Task.CompletedTask; + } + + public override Task StopAsync() + { + return Task.CompletedTask; + } + } +} diff --git a/src/Servers/Kestrel/Core/src/Internal/Http/HttpProtocol.FeatureCollection.cs b/src/Servers/Kestrel/Core/src/Internal/Http/HttpProtocol.FeatureCollection.cs index 4289a46b883e..9be0b8a6dc2d 100644 --- a/src/Servers/Kestrel/Core/src/Internal/Http/HttpProtocol.FeatureCollection.cs +++ b/src/Servers/Kestrel/Core/src/Internal/Http/HttpProtocol.FeatureCollection.cs @@ -20,6 +20,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http public partial class HttpProtocol : IHttpRequestFeature, IHttpResponseFeature, IResponseBodyPipeFeature, + IRequestBodyPipeFeature, IHttpUpgradeFeature, IHttpConnectionFeature, IHttpRequestLifetimeFeature, @@ -94,8 +95,39 @@ IHeaderDictionary IHttpRequestFeature.Headers Stream IHttpRequestFeature.Body { - get => RequestBody; - set => RequestBody = value; + get + { + return RequestBody; + } + set + { + RequestBody = value; + var requestPipeReader = new StreamPipeReader(RequestBody, new StreamPipeReaderOptions( + minimumSegmentSize: KestrelMemoryPool.MinimumSegmentSize, + minimumReadThreshold: KestrelMemoryPool.MinimumSegmentSize / 4, + _context.MemoryPool)); + RequestBodyPipeReader = requestPipeReader; + + // The StreamPipeWrapper needs to be disposed as it hold onto blocks of memory + if (_wrapperObjectsToDispose == null) + { + _wrapperObjectsToDispose = new List(); + } + _wrapperObjectsToDispose.Add(requestPipeReader); + } + } + + PipeReader IRequestBodyPipeFeature.RequestBodyPipe + { + get + { + return RequestBodyPipeReader; + } + set + { + RequestBodyPipeReader = value; + RequestBody = new ReadOnlyPipeStream(RequestBodyPipeReader); + } } int IHttpResponseFeature.StatusCode @@ -275,7 +307,7 @@ async Task IHttpUpgradeFeature.UpgradeAsync() await FlushAsync(); - return _streams.Upgrade(); + return bodyControl.Upgrade(); } void IHttpRequestLifetimeFeature.Abort() diff --git a/src/Servers/Kestrel/Core/src/Internal/Http/HttpProtocol.Generated.cs b/src/Servers/Kestrel/Core/src/Internal/Http/HttpProtocol.Generated.cs index 8e5fbc44aab1..aec1e2318f4b 100644 --- a/src/Servers/Kestrel/Core/src/Internal/Http/HttpProtocol.Generated.cs +++ b/src/Servers/Kestrel/Core/src/Internal/Http/HttpProtocol.Generated.cs @@ -16,6 +16,7 @@ public partial class HttpProtocol : IFeatureCollection private static readonly Type IHttpRequestFeatureType = typeof(IHttpRequestFeature); private static readonly Type IHttpResponseFeatureType = typeof(IHttpResponseFeature); private static readonly Type IResponseBodyPipeFeatureType = typeof(IResponseBodyPipeFeature); + private static readonly Type IRequestBodyPipeFeatureType = typeof(IRequestBodyPipeFeature); private static readonly Type IHttpRequestIdentifierFeatureType = typeof(IHttpRequestIdentifierFeature); private static readonly Type IServiceProvidersFeatureType = typeof(IServiceProvidersFeature); private static readonly Type IHttpRequestLifetimeFeatureType = typeof(IHttpRequestLifetimeFeature); @@ -41,6 +42,7 @@ public partial class HttpProtocol : IFeatureCollection private object _currentIHttpRequestFeature; private object _currentIHttpResponseFeature; private object _currentIResponseBodyPipeFeature; + private object _currentIRequestBodyPipeFeature; private object _currentIHttpRequestIdentifierFeature; private object _currentIServiceProvidersFeature; private object _currentIHttpRequestLifetimeFeature; @@ -72,6 +74,7 @@ private void FastReset() _currentIHttpRequestFeature = this; _currentIHttpResponseFeature = this; _currentIResponseBodyPipeFeature = this; + _currentIRequestBodyPipeFeature = this; _currentIHttpUpgradeFeature = this; _currentIHttpRequestIdentifierFeature = this; _currentIHttpRequestLifetimeFeature = this; @@ -160,6 +163,10 @@ object IFeatureCollection.this[Type key] { feature = _currentIResponseBodyPipeFeature; } + else if (key == IRequestBodyPipeFeatureType) + { + feature = _currentIRequestBodyPipeFeature; + } else if (key == IHttpRequestIdentifierFeatureType) { feature = _currentIHttpRequestIdentifierFeature; @@ -268,6 +275,10 @@ object IFeatureCollection.this[Type key] { _currentIResponseBodyPipeFeature = value; } + else if (key == IRequestBodyPipeFeatureType) + { + _currentIRequestBodyPipeFeature = value; + } else if (key == IHttpRequestIdentifierFeatureType) { _currentIHttpRequestIdentifierFeature = value; @@ -374,6 +385,10 @@ TFeature IFeatureCollection.Get() { feature = (TFeature)_currentIResponseBodyPipeFeature; } + else if (typeof(TFeature) == typeof(IRequestBodyPipeFeature)) + { + feature = (TFeature)_currentIRequestBodyPipeFeature; + } else if (typeof(TFeature) == typeof(IHttpRequestIdentifierFeature)) { feature = (TFeature)_currentIHttpRequestIdentifierFeature; @@ -486,6 +501,10 @@ void IFeatureCollection.Set(TFeature feature) { _currentIResponseBodyPipeFeature = feature; } + else if (typeof(TFeature) == typeof(IRequestBodyPipeFeature)) + { + _currentIRequestBodyPipeFeature = feature; + } else if (typeof(TFeature) == typeof(IHttpRequestIdentifierFeature)) { _currentIHttpRequestIdentifierFeature = feature; @@ -590,6 +609,10 @@ private IEnumerable> FastEnumerable() { yield return new KeyValuePair(IResponseBodyPipeFeatureType, _currentIResponseBodyPipeFeature); } + if (_currentIRequestBodyPipeFeature != null) + { + yield return new KeyValuePair(IRequestBodyPipeFeatureType, _currentIRequestBodyPipeFeature); + } if (_currentIHttpRequestIdentifierFeature != null) { yield return new KeyValuePair(IHttpRequestIdentifierFeatureType, _currentIHttpRequestIdentifierFeature); diff --git a/src/Servers/Kestrel/Core/src/Internal/Http/HttpProtocol.cs b/src/Servers/Kestrel/Core/src/Internal/Http/HttpProtocol.cs index d560f9a1eebb..724aa8514ecc 100644 --- a/src/Servers/Kestrel/Core/src/Internal/Http/HttpProtocol.cs +++ b/src/Servers/Kestrel/Core/src/Internal/Http/HttpProtocol.cs @@ -29,8 +29,7 @@ public abstract partial class HttpProtocol : IDefaultHttpContextContainer, IHttp private static readonly byte[] _bytesTransferEncodingChunked = Encoding.ASCII.GetBytes("\r\nTransfer-Encoding: chunked"); private static readonly byte[] _bytesServer = Encoding.ASCII.GetBytes("\r\nServer: " + Constants.ServerName); - protected Streams _streams; - private HttpResponsePipeWriter _originalPipeWriter; + protected BodyControl bodyControl; private Stack, object>> _onStarting; private Stack, object>> _onCompleted; @@ -75,8 +74,6 @@ public HttpProtocol(HttpConnectionContext context) public IHttpResponseControl HttpResponseControl { get; set; } - public Pipe RequestBodyPipe { get; protected set; } - public ServiceContext ServiceContext => _context.ServiceContext; private IPEndPoint LocalEndPoint => _context.LocalEndPoint; private IPEndPoint RemoteEndPoint => _context.RemoteEndPoint; @@ -193,6 +190,7 @@ private void HttpVersionSetSlow(string value) public IHeaderDictionary RequestHeaders { get; set; } public Stream RequestBody { get; set; } + public PipeReader RequestBodyPipeReader { get; set; } private int _statusCode; public int StatusCode @@ -293,20 +291,17 @@ DefaultHttpContext IDefaultHttpContextContainer.HttpContext } } - public void InitializeStreams(MessageBody messageBody) + public void InitializeBodyControl(MessageBody messageBody) { - if (_streams == null) + if (bodyControl == null) { - var pipeWriter = new HttpResponsePipeWriter(this); - _streams = new Streams(bodyControl: this, pipeWriter); - _originalPipeWriter = pipeWriter; + bodyControl = new BodyControl(bodyControl: this, this); } - (RequestBody, ResponseBody) = _streams.Start(messageBody); - ResponsePipeWriter = _originalPipeWriter; + (RequestBody, ResponseBody, RequestBodyPipeReader, ResponsePipeWriter) = bodyControl.Start(messageBody); } - public void StopStreams() => _streams.Stop(); + public void StopBodies() => bodyControl.Stop(); // For testing internal void ResetState() @@ -460,7 +455,7 @@ protected void AbortRequest() protected void PoisonRequestBodyStream(Exception abortReason) { - _streams?.Abort(abortReason); + bodyControl?.Abort(abortReason); } // Prevents the RequestAborted token from firing for the duration of the request. @@ -566,7 +561,7 @@ private async Task ProcessRequests(IHttpApplication applicat IsUpgradableRequest = messageBody.RequestUpgrade; - InitializeStreams(messageBody); + InitializeBodyControl(messageBody); var context = application.CreateContext(this); @@ -608,7 +603,7 @@ private async Task ProcessRequests(IHttpApplication applicat // At this point all user code that needs use to the request or response streams has completed. // Using these streams in the OnCompleted callback is not allowed. - StopStreams(); + StopBodies(); // 4XX responses are written by TryProduceInvalidRequestResponse during connection tear down. if (_requestRejectedException == null) @@ -652,9 +647,6 @@ private async Task ProcessRequests(IHttpApplication applicat if (HasStartedConsumingRequestBody) { - RequestBodyPipe.Reader.Complete(); - - // Wait for Http1MessageBody.PumpAsync() to call RequestBodyPipe.Writer.Complete(). await messageBody.StopAsync(); } } @@ -695,7 +687,6 @@ protected Task FireOnStarting() { return FireOnStartingMayAwait(onStarting); } - } private Task FireOnStartingMayAwait(Stack, object>> onStarting) @@ -1250,8 +1241,14 @@ public void SetBadRequestState(BadHttpRequestException ex) _requestRejectedException = ex; } - protected void ReportApplicationError(Exception ex) + public void ReportApplicationError(Exception ex) { + // ReportApplicationError can be called with a null exception from MessageBody + if (ex == null) + { + return; + } + if (_applicationException == null) { _applicationException = ex; diff --git a/src/Servers/Kestrel/Core/src/Internal/Http/HttpRequestPipeReader.cs b/src/Servers/Kestrel/Core/src/Internal/Http/HttpRequestPipeReader.cs new file mode 100644 index 000000000000..b7bce1a005c6 --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Internal/Http/HttpRequestPipeReader.cs @@ -0,0 +1,134 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.IO.Pipelines; +using System.Runtime.CompilerServices; +using System.Runtime.ExceptionServices; +using System.Threading; +using System.Threading.Tasks; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http +{ + /// + /// Default HttpRequest PipeReader implementation to be used by Kestrel. + /// + public class HttpRequestPipeReader : PipeReader + { + private MessageBody _body; + private HttpStreamState _state; + private Exception _error; + + public HttpRequestPipeReader() + { + _state = HttpStreamState.Closed; + } + + public override void AdvanceTo(SequencePosition consumed) + { + ValidateState(); + + _body.AdvanceTo(consumed); + } + + public override void AdvanceTo(SequencePosition consumed, SequencePosition examined) + { + ValidateState(); + + _body.AdvanceTo(consumed, examined); + } + + public override void CancelPendingRead() + { + ValidateState(); + + _body.CancelPendingRead(); + } + + public override void Complete(Exception exception = null) + { + ValidateState(); + + _body.Complete(exception); + } + + public override void OnWriterCompleted(Action callback, object state) + { + ValidateState(); + + _body.OnWriterCompleted(callback, state); + } + + public override ValueTask ReadAsync(CancellationToken cancellationToken = default) + { + ValidateState(cancellationToken); + + return _body.ReadAsync(cancellationToken); + } + + public override bool TryRead(out ReadResult result) + { + ValidateState(); + + return _body.TryRead(out result); + } + + public void StartAcceptingReads(MessageBody body) + { + // Only start if not aborted + if (_state == HttpStreamState.Closed) + { + _state = HttpStreamState.Open; + _body = body; + } + } + + public void StopAcceptingReads() + { + // Can't use dispose (or close) as can be disposed too early by user code + // As exampled in EngineTests.ZeroContentLengthNotSetAutomaticallyForCertainStatusCodes + _state = HttpStreamState.Closed; + _body = null; + } + + public void Abort(Exception error = null) + { + // We don't want to throw an ODE until the app func actually completes. + // If the request is aborted, we throw a TaskCanceledException instead, + // unless error is not null, in which case we throw it. + if (_state != HttpStreamState.Closed) + { + _state = HttpStreamState.Aborted; + _error = error; + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private void ValidateState(CancellationToken cancellationToken = default) + { + var state = _state; + if (state == HttpStreamState.Open) + { + cancellationToken.ThrowIfCancellationRequested(); + } + else if (state == HttpStreamState.Closed) + { + ThrowObjectDisposedException(); + } + else + { + if (_error != null) + { + ExceptionDispatchInfo.Capture(_error).Throw(); + } + else + { + ThrowTaskCanceledException(); + } + } + + void ThrowObjectDisposedException() => throw new ObjectDisposedException(nameof(HttpRequestStream)); + void ThrowTaskCanceledException() => throw new TaskCanceledException(); + } + } +} diff --git a/src/Servers/Kestrel/Core/src/Internal/Http/HttpRequestStream.cs b/src/Servers/Kestrel/Core/src/Internal/Http/HttpRequestStream.cs index 31d73b248187..587fdc55c060 100644 --- a/src/Servers/Kestrel/Core/src/Internal/Http/HttpRequestStream.cs +++ b/src/Servers/Kestrel/Core/src/Internal/Http/HttpRequestStream.cs @@ -3,154 +3,51 @@ using System; using System.IO; -using System.Runtime.CompilerServices; -using System.Runtime.ExceptionServices; +using System.IO.Pipelines; using System.Threading; using System.Threading.Tasks; -using Microsoft.AspNetCore.Http.Features; using Microsoft.AspNetCore.Connections; -using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure; +using Microsoft.AspNetCore.Http.Features; namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http { - internal class HttpRequestStream : ReadOnlyStream + internal class HttpRequestStream : ReadOnlyPipeStream { + private HttpRequestPipeReader _pipeReader; private readonly IHttpBodyControlFeature _bodyControl; - private MessageBody _body; - private HttpStreamState _state; - private Exception _error; - public HttpRequestStream(IHttpBodyControlFeature bodyControl) + public HttpRequestStream(IHttpBodyControlFeature bodyControl, HttpRequestPipeReader pipeReader) + : base (pipeReader) { _bodyControl = bodyControl; - _state = HttpStreamState.Closed; - } - - public override bool CanSeek => false; - - public override long Length - => throw new NotSupportedException(); - - public override long Position - { - get => throw new NotSupportedException(); - set => throw new NotSupportedException(); - } - - public override void Flush() - { - } - - public override Task FlushAsync(CancellationToken cancellationToken) - { - return Task.CompletedTask; - } - - public override long Seek(long offset, SeekOrigin origin) - { - throw new NotSupportedException(); - } - - public override void SetLength(long value) - { - throw new NotSupportedException(); - } - - public override int Read(byte[] buffer, int offset, int count) - { - if (!_bodyControl.AllowSynchronousIO) - { - throw new InvalidOperationException(CoreStrings.SynchronousReadsDisallowed); - } - - return ReadAsync(buffer, offset, count).GetAwaiter().GetResult(); - } - - public override IAsyncResult BeginRead(byte[] buffer, int offset, int count, AsyncCallback callback, object state) - { - var task = ReadAsync(buffer, offset, count, default, state); - if (callback != null) - { - task.ContinueWith(t => callback.Invoke(t)); - } - return task; - } - - public override int EndRead(IAsyncResult asyncResult) - { - return ((Task)asyncResult).GetAwaiter().GetResult(); - } - - private Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken, object state) - { - var tcs = new TaskCompletionSource(state); - var task = ReadAsync(buffer, offset, count, cancellationToken); - task.ContinueWith((task2, state2) => - { - var tcs2 = (TaskCompletionSource)state2; - if (task2.IsCanceled) - { - tcs2.SetCanceled(); - } - else if (task2.IsFaulted) - { - tcs2.SetException(task2.Exception); - } - else - { - tcs2.SetResult(task2.Result); - } - }, tcs, cancellationToken); - return tcs.Task; + _pipeReader = pipeReader; } public override Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) { - ValidateState(cancellationToken); - return ReadAsyncInternal(new Memory(buffer, offset, count), cancellationToken).AsTask(); } public override ValueTask ReadAsync(Memory destination, CancellationToken cancellationToken = default) { - ValidateState(cancellationToken); - return ReadAsyncInternal(destination, cancellationToken); } - private async ValueTask ReadAsyncInternal(Memory buffer, CancellationToken cancellationToken) - { - try - { - return await _body.ReadAsync(buffer, cancellationToken); - } - catch (ConnectionAbortedException ex) - { - throw new TaskCanceledException("The request was aborted", ex); - } - } - - public override Task CopyToAsync(Stream destination, int bufferSize, CancellationToken cancellationToken) + public override int Read(byte[] buffer, int offset, int count) { - if (destination == null) - { - throw new ArgumentNullException(nameof(destination)); - } - if (bufferSize <= 0) + if (!_bodyControl.AllowSynchronousIO) { - throw new ArgumentException(CoreStrings.PositiveNumberRequired, nameof(bufferSize)); + throw new InvalidOperationException(CoreStrings.SynchronousReadsDisallowed); } - ValidateState(cancellationToken); - - return CopyToAsyncInternal(destination, cancellationToken); + return ReadAsync(buffer, offset, count).GetAwaiter().GetResult(); } - private async Task CopyToAsyncInternal(Stream destination, CancellationToken cancellationToken) + private ValueTask ReadAsyncInternal(Memory buffer, CancellationToken cancellationToken) { try { - await _body.CopyToAsync(destination, cancellationToken); + return base.ReadAsync(buffer, cancellationToken); } catch (ConnectionAbortedException ex) { @@ -158,62 +55,13 @@ private async Task CopyToAsyncInternal(Stream destination, CancellationToken can } } - public void StartAcceptingReads(MessageBody body) - { - // Only start if not aborted - if (_state == HttpStreamState.Closed) - { - _state = HttpStreamState.Open; - _body = body; - } - } - - public void StopAcceptingReads() - { - // Can't use dispose (or close) as can be disposed too early by user code - // As exampled in EngineTests.ZeroContentLengthNotSetAutomaticallyForCertainStatusCodes - _state = HttpStreamState.Closed; - _body = null; - } - - public void Abort(Exception error = null) + public override void Flush() { - // We don't want to throw an ODE until the app func actually completes. - // If the request is aborted, we throw a TaskCanceledException instead, - // unless error is not null, in which case we throw it. - if (_state != HttpStreamState.Closed) - { - _state = HttpStreamState.Aborted; - _error = error; - } } - [MethodImpl(MethodImplOptions.AggressiveInlining)] - private void ValidateState(CancellationToken cancellationToken) + public override Task FlushAsync(CancellationToken cancellationToken) { - var state = _state; - if (state == HttpStreamState.Open) - { - cancellationToken.ThrowIfCancellationRequested(); - } - else if (state == HttpStreamState.Closed) - { - ThrowObjectDisposedException(); - } - else - { - if (_error != null) - { - ExceptionDispatchInfo.Capture(_error).Throw(); - } - else - { - ThrowTaskCanceledException(); - } - } - - void ThrowObjectDisposedException() => throw new ObjectDisposedException(nameof(HttpRequestStream)); - void ThrowTaskCanceledException() => throw new TaskCanceledException(); + return Task.CompletedTask; } } } diff --git a/src/Servers/Kestrel/Core/src/Internal/Http/HttpResponsePipeWriter.cs b/src/Servers/Kestrel/Core/src/Internal/Http/HttpResponsePipeWriter.cs index 0cc5a65fa0b2..2a406b067f86 100644 --- a/src/Servers/Kestrel/Core/src/Internal/Http/HttpResponsePipeWriter.cs +++ b/src/Servers/Kestrel/Core/src/Internal/Http/HttpResponsePipeWriter.cs @@ -34,6 +34,7 @@ public override void CancelPendingFlush() public override void Complete(Exception exception = null) { + ValidateState(); _pipeControl.Complete(exception); } @@ -57,6 +58,7 @@ public override Span GetSpan(int sizeHint = 0) public override void OnReaderCompleted(Action callback, object state) { + ValidateState(); throw new NotSupportedException(); } diff --git a/src/Servers/Kestrel/Core/src/Internal/Http/HttpResponseStream.cs b/src/Servers/Kestrel/Core/src/Internal/Http/HttpResponseStream.cs index 054ffaffa09f..0c9f19075e61 100644 --- a/src/Servers/Kestrel/Core/src/Internal/Http/HttpResponseStream.cs +++ b/src/Servers/Kestrel/Core/src/Internal/Http/HttpResponseStream.cs @@ -38,20 +38,5 @@ public override void Flush() base.Flush(); } - - public void StartAcceptingWrites() - { - _pipeWriter.StartAcceptingWrites(); - } - - public void StopAcceptingWrites() - { - _pipeWriter.StopAcceptingWrites(); - } - - public void Abort() - { - _pipeWriter.Abort(); - } } } diff --git a/src/Servers/Kestrel/Core/src/Internal/Http/MessageBody.cs b/src/Servers/Kestrel/Core/src/Internal/Http/MessageBody.cs index 5dfe29549795..159daffcd085 100644 --- a/src/Servers/Kestrel/Core/src/Internal/Http/MessageBody.cs +++ b/src/Servers/Kestrel/Core/src/Internal/Http/MessageBody.cs @@ -2,8 +2,6 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; -using System.Buffers; -using System.IO; using System.IO.Pipelines; using System.Threading; using System.Threading.Tasks; @@ -13,8 +11,8 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http { public abstract class MessageBody { - private static readonly MessageBody _zeroContentLengthClose = new ForZeroContentLength(keepAlive: false); - private static readonly MessageBody _zeroContentLengthKeepAlive = new ForZeroContentLength(keepAlive: true); + private static readonly MessageBody _zeroContentLengthClose = new ZeroContentLengthMessageBody(keepAlive: false); + private static readonly MessageBody _zeroContentLengthKeepAlive = new ZeroContentLengthMessageBody(keepAlive: true); private readonly HttpProtocol _context; private readonly MinDataRate _minRequestBodyDataRate; @@ -23,9 +21,9 @@ public abstract class MessageBody private long _consumedBytes; private bool _stopped; - private bool _timingEnabled; - private bool _backpressure; - private long _alreadyTimedBytes; + protected bool _timingEnabled; + protected bool _backpressure; + protected long _alreadyTimedBytes; protected MessageBody(HttpProtocol context, MinDataRate minRequestBodyDataRate) { @@ -45,94 +43,19 @@ protected MessageBody(HttpProtocol context, MinDataRate minRequestBodyDataRate) protected IKestrelTrace Log => _context.ServiceContext.Log; - public virtual async ValueTask ReadAsync(Memory buffer, CancellationToken cancellationToken = default(CancellationToken)) - { - TryStart(); - - while (true) - { - var result = await StartTimingReadAsync(cancellationToken); - var readableBuffer = result.Buffer; - var readableBufferLength = readableBuffer.Length; - StopTimingRead(readableBufferLength); - - var consumed = readableBuffer.End; - var actual = 0; - - try - { - if (readableBufferLength != 0) - { - // buffer.Length is int - actual = (int)Math.Min(readableBufferLength, buffer.Length); - - // Make sure we don't double-count bytes on the next read. - _alreadyTimedBytes = readableBufferLength - actual; - - var slice = actual == readableBufferLength ? readableBuffer : readableBuffer.Slice(0, actual); - consumed = slice.End; - slice.CopyTo(buffer.Span); - - return actual; - } + public abstract void AdvanceTo(SequencePosition consumed); - if (result.IsCompleted) - { - TryStop(); - return 0; - } - } - finally - { - _context.RequestBodyPipe.Reader.AdvanceTo(consumed); + public abstract void AdvanceTo(SequencePosition consumed, SequencePosition examined); - // Update the flow-control window after advancing the pipe reader, so we don't risk overfilling - // the pipe despite the client being well-behaved. - OnDataRead(actual); - } - } - } - - public virtual async Task CopyToAsync(Stream destination, CancellationToken cancellationToken = default(CancellationToken)) - { - TryStart(); + public abstract bool TryRead(out ReadResult readResult); - while (true) - { - var result = await StartTimingReadAsync(cancellationToken); - var readableBuffer = result.Buffer; - var readableBufferLength = readableBuffer.Length; - StopTimingRead(readableBufferLength); + public abstract void OnWriterCompleted(Action callback, object state); - try - { - if (readableBufferLength != 0) - { - foreach (var memory in readableBuffer) - { - // REVIEW: This *could* be slower if 2 things are true - // - The WriteAsync(ReadOnlyMemory) isn't overridden on the destination - // - We change the Kestrel Memory Pool to not use pinned arrays but instead use native memory - await destination.WriteAsync(memory, cancellationToken); - } - } + public abstract void Complete(Exception exception); - if (result.IsCompleted) - { - TryStop(); - return; - } - } - finally - { - _context.RequestBodyPipe.Reader.AdvanceTo(readableBuffer.End); + public abstract void CancelPendingRead(); - // Update the flow-control window after advancing the pipe reader, so we don't risk overfilling - // the pipe despite the client being well-behaved. - OnDataRead(readableBufferLength); - } - } - } + public abstract ValueTask ReadAsync(CancellationToken cancellationToken = default); public virtual Task ConsumeAsync() { @@ -161,7 +84,7 @@ protected void TryProduceContinue() } } - private void TryStart() + protected void TryStart() { if (_context.HasStartedConsumingRequestBody) { @@ -185,7 +108,7 @@ private void TryStart() OnReadStarted(); } - private void TryStop() + protected void TryStop() { if (_stopped) { @@ -232,12 +155,13 @@ protected void AddAndCheckConsumedBytes(long consumedBytes) } } - private ValueTask StartTimingReadAsync(CancellationToken cancellationToken) + protected ValueTask StartTimingReadAsync(ValueTask readAwaitable, CancellationToken cancellationToken) { - var readAwaitable = _context.RequestBodyPipe.Reader.ReadAsync(cancellationToken); if (!readAwaitable.IsCompleted && _timingEnabled) { + TryProduceContinue(); + _backpressure = true; _context.TimeoutControl.StartTimingRead(); } @@ -245,7 +169,7 @@ private ValueTask StartTimingReadAsync(CancellationToken cancellatio return readAwaitable; } - private void StopTimingRead(long bytesRead) + protected void StopTimingRead(long bytesRead) { _context.TimeoutControl.BytesRead(bytesRead - _alreadyTimedBytes); _alreadyTimedBytes = 0; @@ -256,24 +180,5 @@ private void StopTimingRead(long bytesRead) _context.TimeoutControl.StopTimingRead(); } } - - private class ForZeroContentLength : MessageBody - { - public ForZeroContentLength(bool keepAlive) - : base(null, null) - { - RequestKeepAlive = keepAlive; - } - - public override bool IsEmpty => true; - - public override ValueTask ReadAsync(Memory buffer, CancellationToken cancellationToken = default(CancellationToken)) => new ValueTask(0); - - public override Task CopyToAsync(Stream destination, CancellationToken cancellationToken = default(CancellationToken)) => Task.CompletedTask; - - public override Task ConsumeAsync() => Task.CompletedTask; - - public override Task StopAsync() => Task.CompletedTask; - } } } diff --git a/src/Servers/Kestrel/Core/src/Internal/Http/ZeroContentLengthMessageBody.cs b/src/Servers/Kestrel/Core/src/Internal/Http/ZeroContentLengthMessageBody.cs new file mode 100644 index 000000000000..355f534be02b --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Internal/Http/ZeroContentLengthMessageBody.cs @@ -0,0 +1,43 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.IO.Pipelines; +using System.Threading; +using System.Threading.Tasks; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http +{ + public class ZeroContentLengthMessageBody : MessageBody + { + public ZeroContentLengthMessageBody(bool keepAlive) + : base(null, null) + { + RequestKeepAlive = keepAlive; + } + + public override bool IsEmpty => true; + + public override ValueTask ReadAsync(CancellationToken cancellationToken = default(CancellationToken)) => new ValueTask(new ReadResult(default, isCanceled: false, isCompleted: true)); + + public override Task ConsumeAsync() => Task.CompletedTask; + + public override Task StopAsync() => Task.CompletedTask; + + public override void AdvanceTo(SequencePosition consumed) { } + + public override void AdvanceTo(SequencePosition consumed, SequencePosition examined) { } + + public override bool TryRead(out ReadResult result) + { + result = new ReadResult(default, isCanceled: false, isCompleted: true); + return true; + } + + public override void OnWriterCompleted(Action callback, object state) { } + + public override void Complete(Exception ex) { } + + public override void CancelPendingRead() { } + } +} diff --git a/src/Servers/Kestrel/Core/src/Internal/Http2/Http2MessageBody.cs b/src/Servers/Kestrel/Core/src/Internal/Http2/Http2MessageBody.cs index 7427e98e14a5..cba9b491b221 100644 --- a/src/Servers/Kestrel/Core/src/Internal/Http2/Http2MessageBody.cs +++ b/src/Servers/Kestrel/Core/src/Internal/Http2/Http2MessageBody.cs @@ -1,15 +1,19 @@ // Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. +using System; +using System.IO.Pipelines; +using System.Threading; using System.Threading.Tasks; +using Microsoft.AspNetCore.Connections; using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http; -using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure; namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2 { public class Http2MessageBody : MessageBody { private readonly Http2Stream _context; + private ReadResult _readResult; private Http2MessageBody(Http2Stream context, MinDataRate minRequestBodyDataRate) : base(context, minRequestBodyDataRate) @@ -51,5 +55,75 @@ public static MessageBody For(Http2Stream context, MinDataRate minRequestBodyDat return new Http2MessageBody(context, minRequestBodyDataRate); } + + public override void AdvanceTo(SequencePosition consumed) + { + AdvanceTo(consumed, consumed); + } + + public override void AdvanceTo(SequencePosition consumed, SequencePosition examined) + { + var dataLength = _readResult.Buffer.Slice(_readResult.Buffer.Start, consumed).Length; + _context.RequestBodyPipe.Reader.AdvanceTo(consumed, examined); + OnDataRead(dataLength); + } + + public override bool TryRead(out ReadResult readResult) + { + return _context.RequestBodyPipe.Reader.TryRead(out readResult); + } + + public override async ValueTask ReadAsync(CancellationToken cancellationToken = default) + { + TryStart(); + + try + { + var readAwaitable = _context.RequestBodyPipe.Reader.ReadAsync(cancellationToken); + + _readResult = await StartTimingReadAsync(readAwaitable, cancellationToken); + } + catch (ConnectionAbortedException ex) + { + throw new TaskCanceledException("The request was aborted", ex); + } + + StopTimingRead(_readResult.Buffer.Length); + + if (_readResult.IsCompleted) + { + TryStop(); + } + + return _readResult; + } + + public override void Complete(Exception exception) + { + _context.RequestBodyPipe.Reader.Complete(); + _context.ReportApplicationError(exception); + } + + public override void OnWriterCompleted(Action callback, object state) + { + _context.RequestBodyPipe.Reader.OnWriterCompleted(callback, state); + } + + public override void CancelPendingRead() + { + _context.RequestBodyPipe.Reader.CancelPendingRead(); + } + + protected override Task OnStopAsync() + { + if (!_context.HasStartedConsumingRequestBody) + { + return Task.CompletedTask; + } + + _context.RequestBodyPipe.Reader.Complete(); + + return Task.CompletedTask; + } } } diff --git a/src/Servers/Kestrel/Core/src/Internal/Http2/Http2Stream.cs b/src/Servers/Kestrel/Core/src/Internal/Http2/Http2Stream.cs index 661c2ae06895..e2a5b3b83d2e 100644 --- a/src/Servers/Kestrel/Core/src/Internal/Http2/Http2Stream.cs +++ b/src/Servers/Kestrel/Core/src/Internal/Http2/Http2Stream.cs @@ -25,6 +25,8 @@ public abstract partial class Http2Stream : HttpProtocol, IThreadPoolWorkItem private readonly StreamInputFlowControl _inputFlowControl; private readonly StreamOutputFlowControl _outputFlowControl; + public Pipe RequestBodyPipe { get; } + internal long DrainExpirationTicks { get; set; } private StreamCompletionFlags _completionState; diff --git a/src/Servers/Kestrel/Core/src/Internal/Infrastructure/Streams.cs b/src/Servers/Kestrel/Core/src/Internal/Infrastructure/BodyControl.cs similarity index 52% rename from src/Servers/Kestrel/Core/src/Internal/Infrastructure/Streams.cs rename to src/Servers/Kestrel/Core/src/Internal/Infrastructure/BodyControl.cs index ee75017f6342..5fbeead03e39 100644 --- a/src/Servers/Kestrel/Core/src/Internal/Infrastructure/Streams.cs +++ b/src/Servers/Kestrel/Core/src/Internal/Infrastructure/BodyControl.cs @@ -3,26 +3,34 @@ using System; using System.IO; +using System.IO.Pipelines; using Microsoft.AspNetCore.Http.Features; using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http; namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure { - public class Streams + public class BodyControl { private static readonly ThrowingWasUpgradedWriteOnlyStream _throwingResponseStream = new ThrowingWasUpgradedWriteOnlyStream(); private readonly HttpResponseStream _response; + private readonly HttpResponsePipeWriter _responseWriter; + private readonly HttpRequestPipeReader _requestReader; private readonly HttpRequestStream _request; + private readonly HttpRequestPipeReader _emptyRequestReader; private readonly WrappingStream _upgradeableResponse; private readonly HttpRequestStream _emptyRequest; private readonly Stream _upgradeStream; - public Streams(IHttpBodyControlFeature bodyControl, HttpResponsePipeWriter writer) + public BodyControl(IHttpBodyControlFeature bodyControl, IHttpResponseControl responseControl) { - _request = new HttpRequestStream(bodyControl); - _emptyRequest = new HttpRequestStream(bodyControl); - _response = new HttpResponseStream(bodyControl, writer); + _requestReader = new HttpRequestPipeReader(); + _request = new HttpRequestStream(bodyControl, _requestReader); + _emptyRequestReader = new HttpRequestPipeReader(); + _emptyRequest = new HttpRequestStream(bodyControl, _emptyRequestReader); + + _responseWriter = new HttpResponsePipeWriter(responseControl); + _response = new HttpResponseStream(bodyControl, _responseWriter); _upgradeableResponse = new WrappingStream(_response); _upgradeStream = new HttpUpgradeStream(_request, _response); } @@ -35,37 +43,37 @@ public Stream Upgrade() return _upgradeStream; } - public (Stream request, Stream response) Start(MessageBody body) + public (Stream request, Stream response, PipeReader reader, PipeWriter writer) Start(MessageBody body) { - _request.StartAcceptingReads(body); - _emptyRequest.StartAcceptingReads(MessageBody.ZeroContentLengthClose); - _response.StartAcceptingWrites(); + _requestReader.StartAcceptingReads(body); + _emptyRequestReader.StartAcceptingReads(MessageBody.ZeroContentLengthClose); + _responseWriter.StartAcceptingWrites(); if (body.RequestUpgrade) { // until Upgrade() is called, context.Response.Body should use the normal output stream _upgradeableResponse.SetInnerStream(_response); // upgradeable requests should never have a request body - return (_emptyRequest, _upgradeableResponse); + return (_emptyRequest, _upgradeableResponse, _emptyRequestReader, _responseWriter); } else { - return (_request, _response); + return (_request, _response, _requestReader, _responseWriter); } } public void Stop() { - _request.StopAcceptingReads(); - _emptyRequest.StopAcceptingReads(); - _response.StopAcceptingWrites(); + _requestReader.StopAcceptingReads(); + _emptyRequestReader.StopAcceptingReads(); + _responseWriter.StopAcceptingWrites(); } public void Abort(Exception error) { - _request.Abort(error); - _emptyRequest.Abort(error); - _response.Abort(); + _requestReader.Abort(error); + _emptyRequestReader.Abort(error); + _responseWriter.Abort(); } } } diff --git a/src/Servers/Kestrel/Core/test/BodyControlTests.cs b/src/Servers/Kestrel/Core/test/BodyControlTests.cs new file mode 100644 index 000000000000..397b92da856f --- /dev/null +++ b/src/Servers/Kestrel/Core/test/BodyControlTests.cs @@ -0,0 +1,189 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.IO.Pipelines; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Http.Features; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure; +using Moq; +using Xunit; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests +{ + public class BodyControlTests + { + [Fact] + public async Task BodyControlThrowAfterAbort() + { + var bodyControl = new BodyControl(Mock.Of(), Mock.Of()); + var (request, response, requestPipe, responsePipe) = bodyControl.Start(new MockMessageBody()); + + var ex = new Exception("My error"); + bodyControl.Abort(ex); + + await response.WriteAsync(new byte[1], 0, 1); + Assert.Same(ex, + await Assert.ThrowsAsync(() => request.ReadAsync(new byte[1], 0, 1))); + Assert.Same(ex, + await Assert.ThrowsAsync(async () => await requestPipe.ReadAsync())); + } + + [Fact] + public async Task BodyControlThrowOnAbortAfterUpgrade() + { + var bodyControl = new BodyControl(Mock.Of(), Mock.Of()); + var (request, response, requestPipe, responsePipe) = bodyControl.Start(new MockMessageBody(upgradeable: true)); + + var upgrade = bodyControl.Upgrade(); + var ex = new Exception("My error"); + bodyControl.Abort(ex); + + var writeEx = await Assert.ThrowsAsync(() => response.WriteAsync(new byte[1], 0, 1)); + Assert.Equal(CoreStrings.ResponseStreamWasUpgraded, writeEx.Message); + + Assert.Same(ex, + await Assert.ThrowsAsync(() => request.ReadAsync(new byte[1], 0, 1))); + + Assert.Same(ex, + await Assert.ThrowsAsync(() => upgrade.ReadAsync(new byte[1], 0, 1))); + + Assert.Same(ex, + await Assert.ThrowsAsync(async () => await requestPipe.ReadAsync())); + + await upgrade.WriteAsync(new byte[1], 0, 1); + } + + [Fact] + public async Task BodyControlThrowOnUpgradeAfterAbort() + { + var bodyControl = new BodyControl(Mock.Of(), Mock.Of()); + + var (request, response, requestPipe, responsePipe) = bodyControl.Start(new MockMessageBody(upgradeable: true)); + var ex = new Exception("My error"); + bodyControl.Abort(ex); + + var upgrade = bodyControl.Upgrade(); + + var writeEx = await Assert.ThrowsAsync(() => response.WriteAsync(new byte[1], 0, 1)); + Assert.Equal(CoreStrings.ResponseStreamWasUpgraded, writeEx.Message); + + Assert.Same(ex, + await Assert.ThrowsAsync(() => request.ReadAsync(new byte[1], 0, 1))); + + Assert.Same(ex, + await Assert.ThrowsAsync(() => upgrade.ReadAsync(new byte[1], 0, 1))); + Assert.Same(ex, + await Assert.ThrowsAsync(async () => await requestPipe.ReadAsync())); + + await upgrade.WriteAsync(new byte[1], 0, 1); + } + + [Fact] + public async Task RequestPipeMethodsThrowAfterAbort() + { + var bodyControl = new BodyControl(Mock.Of(), Mock.Of()); + + var (_, response, requestPipe, responsePipe) = bodyControl.Start(new MockMessageBody(upgradeable: true)); + var ex = new Exception("My error"); + bodyControl.Abort(ex); + + await response.WriteAsync(new byte[1], 0, 1); + Assert.Same(ex, + Assert.Throws(() => requestPipe.AdvanceTo(new SequencePosition()))); + Assert.Same(ex, + Assert.Throws(() => requestPipe.AdvanceTo(new SequencePosition(), new SequencePosition()))); + Assert.Same(ex, + Assert.Throws(() => requestPipe.CancelPendingRead())); + Assert.Same(ex, + Assert.Throws(() => requestPipe.TryRead(out var res))); + Assert.Same(ex, + Assert.Throws(() => requestPipe.Complete())); + Assert.Same(ex, + Assert.Throws(() => requestPipe.OnWriterCompleted(null, null))); + } + + [Fact] + public async Task RequestPipeThrowsObjectDisposedExceptionAfterStop() + { + var bodyControl = new BodyControl(Mock.Of(), Mock.Of()); + + var (_, response, requestPipe, responsePipe) = bodyControl.Start(new MockMessageBody()); + + bodyControl.Stop(); + + Assert.Throws(() => requestPipe.AdvanceTo(new SequencePosition())); + Assert.Throws(() => requestPipe.AdvanceTo(new SequencePosition(), new SequencePosition())); + Assert.Throws(() => requestPipe.CancelPendingRead()); + Assert.Throws(() => requestPipe.TryRead(out var res)); + Assert.Throws(() => requestPipe.Complete()); + Assert.Throws(() => requestPipe.OnWriterCompleted(null, null)); + await Assert.ThrowsAsync(async () => await requestPipe.ReadAsync()); + } + + [Fact] + public async Task ResponsePipeThrowsObjectDisposedExceptionAfterStop() + { + var bodyControl = new BodyControl(Mock.Of(), Mock.Of()); + + var (_, response, requestPipe, responsePipe) = bodyControl.Start(new MockMessageBody()); + + bodyControl.Stop(); + + Assert.Throws(() => responsePipe.Advance(1)); + Assert.Throws(() => responsePipe.CancelPendingFlush()); + Assert.Throws(() => responsePipe.GetMemory()); + Assert.Throws(() => responsePipe.GetSpan()); + Assert.Throws(() => responsePipe.Complete()); + Assert.Throws(() => responsePipe.OnReaderCompleted(null, null)); + await Assert.ThrowsAsync(async () => await responsePipe.WriteAsync(new Memory())); + await Assert.ThrowsAsync(async () => await responsePipe.FlushAsync()); + } + + private class MockMessageBody : MessageBody + { + public MockMessageBody(bool upgradeable = false) + : base(null, null) + { + RequestUpgrade = upgradeable; + } + + public override void AdvanceTo(SequencePosition consumed) + { + throw new NotImplementedException(); + } + + public override void AdvanceTo(SequencePosition consumed, SequencePosition examined) + { + throw new NotImplementedException(); + } + + public override void CancelPendingRead() + { + throw new NotImplementedException(); + } + + public override void Complete(Exception exception) + { + throw new NotImplementedException(); + } + + public override void OnWriterCompleted(Action callback, object state) + { + throw new NotImplementedException(); + } + + public override ValueTask ReadAsync(CancellationToken cancellationToken = default) + { + throw new NotImplementedException(); + } + + public override bool TryRead(out ReadResult readResult) + { + throw new NotImplementedException(); + } + } + } +} diff --git a/src/Servers/Kestrel/Core/test/Http1ConnectionTests.cs b/src/Servers/Kestrel/Core/test/Http1ConnectionTests.cs index 17263e2f6daa..533c24c9ca4a 100644 --- a/src/Servers/Kestrel/Core/test/Http1ConnectionTests.cs +++ b/src/Servers/Kestrel/Core/test/Http1ConnectionTests.cs @@ -353,7 +353,7 @@ public void InitializeStreamsResetsStreams() { // Arrange var messageBody = Http1MessageBody.For(Kestrel.Core.Internal.Http.HttpVersion.Http11, (HttpRequestHeaders)_http1Connection.RequestHeaders, _http1Connection); - _http1Connection.InitializeStreams(messageBody); + _http1Connection.InitializeBodyControl(messageBody); var originalRequestBody = _http1Connection.RequestBody; var originalResponseBody = _http1Connection.ResponseBody; @@ -361,7 +361,7 @@ public void InitializeStreamsResetsStreams() _http1Connection.ResponseBody = new MemoryStream(); // Act - _http1Connection.InitializeStreams(messageBody); + _http1Connection.InitializeBodyControl(messageBody); // Assert Assert.Same(originalRequestBody, _http1Connection.RequestBody); diff --git a/src/Servers/Kestrel/Core/test/HttpProtocolFeatureCollectionTests.cs b/src/Servers/Kestrel/Core/test/HttpProtocolFeatureCollectionTests.cs index 5021fe334c9a..28b1df7c7ab5 100644 --- a/src/Servers/Kestrel/Core/test/HttpProtocolFeatureCollectionTests.cs +++ b/src/Servers/Kestrel/Core/test/HttpProtocolFeatureCollectionTests.cs @@ -118,6 +118,7 @@ public void FeaturesSetByTypeSameAsGeneric() _collection[typeof(IHttpRequestFeature)] = CreateHttp1Connection(); _collection[typeof(IHttpResponseFeature)] = CreateHttp1Connection(); _collection[typeof(IResponseBodyPipeFeature)] = CreateHttp1Connection(); + _collection[typeof(IRequestBodyPipeFeature)] = CreateHttp1Connection(); _collection[typeof(IHttpRequestIdentifierFeature)] = CreateHttp1Connection(); _collection[typeof(IHttpRequestLifetimeFeature)] = CreateHttp1Connection(); _collection[typeof(IHttpConnectionFeature)] = CreateHttp1Connection(); @@ -138,6 +139,7 @@ public void FeaturesSetByGenericSameAsByType() _collection.Set(CreateHttp1Connection()); _collection.Set(CreateHttp1Connection()); _collection.Set(CreateHttp1Connection()); + _collection.Set(CreateHttp1Connection()); _collection.Set(CreateHttp1Connection()); _collection.Set(CreateHttp1Connection()); _collection.Set(CreateHttp1Connection()); @@ -176,6 +178,7 @@ private void CompareGenericGetterToIndexer() Assert.Same(_collection.Get(), _collection[typeof(IHttpRequestFeature)]); Assert.Same(_collection.Get(), _collection[typeof(IHttpResponseFeature)]); Assert.Same(_collection.Get(), _collection[typeof(IResponseBodyPipeFeature)]); + Assert.Same(_collection.Get(), _collection[typeof(IRequestBodyPipeFeature)]); Assert.Same(_collection.Get(), _collection[typeof(IHttpRequestIdentifierFeature)]); Assert.Same(_collection.Get(), _collection[typeof(IHttpRequestLifetimeFeature)]); Assert.Same(_collection.Get(), _collection[typeof(IHttpConnectionFeature)]); diff --git a/src/Servers/Kestrel/Core/test/HttpRequestPipeReaderTests.cs b/src/Servers/Kestrel/Core/test/HttpRequestPipeReaderTests.cs new file mode 100644 index 000000000000..97f0c91cba7c --- /dev/null +++ b/src/Servers/Kestrel/Core/test/HttpRequestPipeReaderTests.cs @@ -0,0 +1,45 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http; +using Xunit; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests +{ + public class HttpRequestPipeReaderTests + { + [Fact] + public async Task StopAcceptingReadsCausesReadToThrowObjectDisposedException() + { + var pipeReader = new HttpRequestPipeReader(); + pipeReader.StartAcceptingReads(null); + pipeReader.StopAcceptingReads(); + + // Validation for ReadAsync occurs in an async method in ReadOnlyPipeStream. + await Assert.ThrowsAsync(async () => { await pipeReader.ReadAsync(); }); + } + [Fact] + public async Task AbortCausesReadToCancel() + { + var pipeReader = new HttpRequestPipeReader(); + + pipeReader.StartAcceptingReads(null); + pipeReader.Abort(); + await Assert.ThrowsAsync(() => pipeReader.ReadAsync().AsTask()); + } + + [Fact] + public async Task AbortWithErrorCausesReadToCancel() + { + var pipeReader = new HttpRequestPipeReader(); + + pipeReader.StartAcceptingReads(null); + var error = new Exception(); + pipeReader.Abort(error); + var exception = await Assert.ThrowsAsync(() => pipeReader.ReadAsync().AsTask()); + Assert.Same(error, exception); + } + } +} diff --git a/src/Servers/Kestrel/Core/test/HttpRequestStreamTests.cs b/src/Servers/Kestrel/Core/test/HttpRequestStreamTests.cs index ee3c21a042cf..5eaaf242b2c3 100644 --- a/src/Servers/Kestrel/Core/test/HttpRequestStreamTests.cs +++ b/src/Servers/Kestrel/Core/test/HttpRequestStreamTests.cs @@ -3,6 +3,7 @@ using System; using System.IO; +using System.IO.Pipelines; using System.Threading; using System.Threading.Tasks; using Microsoft.AspNetCore.Http.Features; @@ -17,49 +18,49 @@ public class HttpRequestStreamTests [Fact] public void CanReadReturnsTrue() { - var stream = new HttpRequestStream(Mock.Of()); + var stream = new HttpRequestStream(Mock.Of(), new HttpRequestPipeReader()); Assert.True(stream.CanRead); } [Fact] public void CanSeekReturnsFalse() { - var stream = new HttpRequestStream(Mock.Of()); + var stream = new HttpRequestStream(Mock.Of(), new HttpRequestPipeReader()); Assert.False(stream.CanSeek); } [Fact] public void CanWriteReturnsFalse() { - var stream = new HttpRequestStream(Mock.Of()); + var stream = new HttpRequestStream(Mock.Of(), new HttpRequestPipeReader()); Assert.False(stream.CanWrite); } [Fact] public void SeekThrows() { - var stream = new HttpRequestStream(Mock.Of()); + var stream = new HttpRequestStream(Mock.Of(), new HttpRequestPipeReader()); Assert.Throws(() => stream.Seek(0, SeekOrigin.Begin)); } [Fact] public void LengthThrows() { - var stream = new HttpRequestStream(Mock.Of()); + var stream = new HttpRequestStream(Mock.Of(), new HttpRequestPipeReader()); Assert.Throws(() => stream.Length); } [Fact] public void SetLengthThrows() { - var stream = new HttpRequestStream(Mock.Of()); + var stream = new HttpRequestStream(Mock.Of(), new HttpRequestPipeReader()); Assert.Throws(() => stream.SetLength(0)); } [Fact] public void PositionThrows() { - var stream = new HttpRequestStream(Mock.Of()); + var stream = new HttpRequestStream(Mock.Of(), new HttpRequestPipeReader()); Assert.Throws(() => stream.Position); Assert.Throws(() => stream.Position = 0); } @@ -67,21 +68,21 @@ public void PositionThrows() [Fact] public void WriteThrows() { - var stream = new HttpRequestStream(Mock.Of()); + var stream = new HttpRequestStream(Mock.Of(), new HttpRequestPipeReader()); Assert.Throws(() => stream.Write(new byte[1], 0, 1)); } [Fact] public void WriteByteThrows() { - var stream = new HttpRequestStream(Mock.Of()); + var stream = new HttpRequestStream(Mock.Of(), new HttpRequestPipeReader()); Assert.Throws(() => stream.WriteByte(0)); } [Fact] public async Task WriteAsyncThrows() { - var stream = new HttpRequestStream(Mock.Of()); + var stream = new HttpRequestStream(Mock.Of(), new HttpRequestPipeReader()); await Assert.ThrowsAsync(() => stream.WriteAsync(new byte[1], 0, 1)); } @@ -89,14 +90,14 @@ public async Task WriteAsyncThrows() // Read-only streams should support Flush according to https://github.com/dotnet/corefx/pull/27327#pullrequestreview-98384813 public void FlushDoesNotThrow() { - var stream = new HttpRequestStream(Mock.Of()); + var stream = new HttpRequestStream(Mock.Of(), new HttpRequestPipeReader()); stream.Flush(); } [Fact] public async Task FlushAsyncDoesNotThrow() { - var stream = new HttpRequestStream(Mock.Of()); + var stream = new HttpRequestStream(Mock.Of(), new HttpRequestPipeReader()); await stream.FlushAsync(); } @@ -104,13 +105,15 @@ public async Task FlushAsyncDoesNotThrow() public async Task SynchronousReadsThrowIfDisallowedByIHttpBodyControlFeature() { var allowSynchronousIO = false; + var mockBodyControl = new Mock(); mockBodyControl.Setup(m => m.AllowSynchronousIO).Returns(() => allowSynchronousIO); var mockMessageBody = new Mock(null, null); - mockMessageBody.Setup(m => m.ReadAsync(It.IsAny>(), CancellationToken.None)).Returns(new ValueTask(0)); + mockMessageBody.Setup(m => m.ReadAsync(CancellationToken.None)).Returns(new ValueTask(new ReadResult(default, isCanceled: false, isCompleted: true))); - var stream = new HttpRequestStream(mockBodyControl.Object); - stream.StartAcceptingReads(mockMessageBody.Object); + var pipeReader = new HttpRequestPipeReader(); + var stream = new HttpRequestStream(mockBodyControl.Object, pipeReader); + pipeReader.StartAcceptingReads(mockMessageBody.Object); Assert.Equal(0, await stream.ReadAsync(new byte[1], 0, 1)); @@ -127,75 +130,89 @@ public async Task SynchronousReadsThrowIfDisallowedByIHttpBodyControlFeature() [Fact] public async Task AbortCausesReadToCancel() { - var stream = new HttpRequestStream(Mock.Of()); - stream.StartAcceptingReads(null); - stream.Abort(); + var pipeReader = new HttpRequestPipeReader(); + + var stream = new HttpRequestStream(Mock.Of(), pipeReader); + pipeReader.StartAcceptingReads(null); + pipeReader.Abort(); await Assert.ThrowsAsync(() => stream.ReadAsync(new byte[1], 0, 1)); } [Fact] public async Task AbortWithErrorCausesReadToCancel() { - var stream = new HttpRequestStream(Mock.Of()); - stream.StartAcceptingReads(null); + var pipeReader = new HttpRequestPipeReader(); + + var stream = new HttpRequestStream(Mock.Of(), pipeReader); + pipeReader.StartAcceptingReads(null); var error = new Exception(); - stream.Abort(error); + pipeReader.Abort(error); var exception = await Assert.ThrowsAsync(() => stream.ReadAsync(new byte[1], 0, 1)); Assert.Same(error, exception); } [Fact] - public void StopAcceptingReadsCausesReadToThrowObjectDisposedException() + public async Task StopAcceptingReadsCausesReadToThrowObjectDisposedException() { - var stream = new HttpRequestStream(Mock.Of()); - stream.StartAcceptingReads(null); - stream.StopAcceptingReads(); - Assert.Throws(() => { stream.ReadAsync(new byte[1], 0, 1); }); + var pipeReader = new HttpRequestPipeReader(); + var stream = new HttpRequestStream(Mock.Of(), pipeReader); + pipeReader.StartAcceptingReads(null); + pipeReader.StopAcceptingReads(); + + // Validation for ReadAsync occurs in an async method in ReadOnlyPipeStream. + await Assert.ThrowsAsync(async () => { await stream.ReadAsync(new byte[1], 0, 1); }); } [Fact] public async Task AbortCausesCopyToAsyncToCancel() { - var stream = new HttpRequestStream(Mock.Of()); - stream.StartAcceptingReads(null); - stream.Abort(); + var pipeReader = new HttpRequestPipeReader(); + var stream = new HttpRequestStream(Mock.Of(), pipeReader); + pipeReader.StartAcceptingReads(null); + pipeReader.Abort(); await Assert.ThrowsAsync(() => stream.CopyToAsync(Mock.Of())); } [Fact] public async Task AbortWithErrorCausesCopyToAsyncToCancel() { - var stream = new HttpRequestStream(Mock.Of()); - stream.StartAcceptingReads(null); + var pipeReader = new HttpRequestPipeReader(); + var stream = new HttpRequestStream(Mock.Of(), pipeReader); + pipeReader.StartAcceptingReads(null); var error = new Exception(); - stream.Abort(error); + pipeReader.Abort(error); var exception = await Assert.ThrowsAsync(() => stream.CopyToAsync(Mock.Of())); Assert.Same(error, exception); } [Fact] - public void StopAcceptingReadsCausesCopyToAsyncToThrowObjectDisposedException() + public async Task StopAcceptingReadsCausesCopyToAsyncToThrowObjectDisposedException() { - var stream = new HttpRequestStream(Mock.Of()); - stream.StartAcceptingReads(null); - stream.StopAcceptingReads(); - Assert.Throws(() => { stream.CopyToAsync(Mock.Of()); }); + var pipeReader = new HttpRequestPipeReader(); + var stream = new HttpRequestStream(Mock.Of(), pipeReader); + pipeReader.StartAcceptingReads(null); + pipeReader.StopAcceptingReads(); + // Validation for CopyToAsync occurs in an async method in ReadOnlyPipeStream. + await Assert.ThrowsAsync(async () => { await stream.CopyToAsync(Mock.Of()); }); } [Fact] public void NullDestinationCausesCopyToAsyncToThrowArgumentNullException() { - var stream = new HttpRequestStream(Mock.Of()); - stream.StartAcceptingReads(null); + var pipeReader = new HttpRequestPipeReader(); + var stream = new HttpRequestStream(Mock.Of(), pipeReader); + pipeReader.StartAcceptingReads(null); Assert.Throws(() => { stream.CopyToAsync(null); }); } [Fact] public void ZeroBufferSizeCausesCopyToAsyncToThrowArgumentException() { - var stream = new HttpRequestStream(Mock.Of()); - stream.StartAcceptingReads(null); - Assert.Throws(() => { stream.CopyToAsync(Mock.Of(), 0); }); + var pipeReader = new HttpRequestPipeReader(); + var stream = new HttpRequestStream(Mock.Of(), new HttpRequestPipeReader()); + pipeReader.StartAcceptingReads(null); + // This is technically a breaking change, to throw an ArgumentoutOfRangeException rather than an ArgumentException + Assert.Throws(() => { stream.CopyToAsync(Mock.Of(), 0); }); } } } diff --git a/src/Servers/Kestrel/Core/test/HttpResponsePipeWriterTests.cs b/src/Servers/Kestrel/Core/test/HttpResponsePipeWriterTests.cs index 9f57f1dffb32..41a2f743b75e 100644 --- a/src/Servers/Kestrel/Core/test/HttpResponsePipeWriterTests.cs +++ b/src/Servers/Kestrel/Core/test/HttpResponsePipeWriterTests.cs @@ -14,6 +14,7 @@ public class HttpResponsePipeWriterTests public void OnReaderCompletedThrowsNotSupported() { var pipeWriter = CreateHttpResponsePipeWriter(); + pipeWriter.StartAcceptingWrites(); Assert.Throws(() => pipeWriter.OnReaderCompleted((a, b) => { }, null)); } @@ -48,30 +49,31 @@ public void GetSpanAfterStopAcceptingWritesThrowsObjectDisposedException() } [Fact] - public void FlushAsyncAfterStopAcceptingWritesThrowsObjectDisposedException() + public void CompleteAfterStopAcceptingWritesThrowsObjectDisposedException() { var pipeWriter = CreateHttpResponsePipeWriter(); pipeWriter.StartAcceptingWrites(); pipeWriter.StopAcceptingWrites(); - var ex = Assert.Throws(() => { pipeWriter.FlushAsync(); }); + var ex = Assert.Throws(() => { pipeWriter.Complete(); }); Assert.Contains(CoreStrings.WritingToResponseBodyAfterResponseCompleted, ex.Message); } [Fact] - public void WriteAsyncAfterStopAcceptingWritesThrowsObjectDisposedException() + public void FlushAsyncAfterStopAcceptingWritesThrowsObjectDisposedException() { var pipeWriter = CreateHttpResponsePipeWriter(); pipeWriter.StartAcceptingWrites(); pipeWriter.StopAcceptingWrites(); - var ex = Assert.Throws(() => { pipeWriter.WriteAsync(new Memory()); }); + var ex = Assert.Throws(() => { pipeWriter.FlushAsync(); }); Assert.Contains(CoreStrings.WritingToResponseBodyAfterResponseCompleted, ex.Message); } [Fact] - public void CompleteCallsStopAcceptingWrites() + public void WriteAsyncAfterStopAcceptingWritesThrowsObjectDisposedException() { var pipeWriter = CreateHttpResponsePipeWriter(); - pipeWriter.Complete(); + pipeWriter.StartAcceptingWrites(); + pipeWriter.StopAcceptingWrites(); var ex = Assert.Throws(() => { pipeWriter.WriteAsync(new Memory()); }); Assert.Contains(CoreStrings.WritingToResponseBodyAfterResponseCompleted, ex.Message); } diff --git a/src/Servers/Kestrel/Core/test/HttpResponseStreamTests.cs b/src/Servers/Kestrel/Core/test/HttpResponseStreamTests.cs index 3f00c58e0acc..4b7e167e9b94 100644 --- a/src/Servers/Kestrel/Core/test/HttpResponseStreamTests.cs +++ b/src/Servers/Kestrel/Core/test/HttpResponseStreamTests.cs @@ -98,8 +98,8 @@ public void StopAcceptingWritesCausesWriteToThrowObjectDisposedException() { var pipeWriter = new HttpResponsePipeWriter(Mock.Of()); var stream = new HttpResponseStream(Mock.Of(), pipeWriter); - stream.StartAcceptingWrites(); - stream.StopAcceptingWrites(); + pipeWriter.StartAcceptingWrites(); + pipeWriter.StopAcceptingWrites(); var ex = Assert.Throws(() => { stream.WriteAsync(new byte[1], 0, 1); }); Assert.Contains(CoreStrings.WritingToResponseBodyAfterResponseCompleted, ex.Message); } @@ -115,7 +115,7 @@ public async Task SynchronousWritesThrowIfDisallowedByIHttpBodyControlFeature() var pipeWriter = new HttpResponsePipeWriter(mockHttpResponseControl.Object); var stream = new HttpResponseStream(mockBodyControl.Object, pipeWriter); - stream.StartAcceptingWrites(); + pipeWriter.StartAcceptingWrites(); // WriteAsync doesn't throw. await stream.WriteAsync(new byte[1], 0, 1); diff --git a/src/Servers/Kestrel/Core/test/MessageBodyTests.cs b/src/Servers/Kestrel/Core/test/MessageBodyTests.cs index 159f1ef14fb2..785097ab1768 100644 --- a/src/Servers/Kestrel/Core/test/MessageBodyTests.cs +++ b/src/Servers/Kestrel/Core/test/MessageBodyTests.cs @@ -2,19 +2,15 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; +using System.Buffers; using System.IO; -using System.IO.Pipelines; -using System.Runtime.InteropServices; using System.Text; using System.Threading; using System.Threading.Tasks; using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Http.Features; -using Microsoft.AspNetCore.Server.Kestrel.Core.Internal; using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http; using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure; -using Microsoft.AspNetCore.Server.Kestrel.Transport.Abstractions.Internal; -using Microsoft.AspNetCore.Testing; using Moq; using Xunit; using Xunit.Sdk; @@ -33,8 +29,9 @@ public async Task CanReadFromContentLength(HttpVersion httpVersion) var body = Http1MessageBody.For(httpVersion, new HttpRequestHeaders { HeaderContentLength = "5" }, input.Http1Connection); var mockBodyControl = new Mock(); mockBodyControl.Setup(m => m.AllowSynchronousIO).Returns(true); - var stream = new HttpRequestStream(mockBodyControl.Object); - stream.StartAcceptingReads(body); + var reader = new HttpRequestPipeReader(); + var stream = new HttpRequestStream(mockBodyControl.Object, reader); + reader.StartAcceptingReads(body); input.Add("Hello"); @@ -47,7 +44,99 @@ public async Task CanReadFromContentLength(HttpVersion httpVersion) count = stream.Read(buffer, 0, buffer.Length); Assert.Equal(0, count); - input.Http1Connection.RequestBodyPipe.Reader.Complete(); + await body.StopAsync(); + } + } + + [Theory] + [InlineData(HttpVersion.Http10)] + [InlineData(HttpVersion.Http11)] + public async Task CanReadFromContentLengthPipeApis(HttpVersion httpVersion) + { + using (var input = new TestInput()) + { + var body = Http1MessageBody.For(httpVersion, new HttpRequestHeaders { HeaderContentLength = "5" }, input.Http1Connection); + var reader = new HttpRequestPipeReader(); + reader.StartAcceptingReads(body); + + input.Add("Hello"); + + var readResult = await reader.ReadAsync(); + + Assert.Equal(5, readResult.Buffer.Length); + AssertASCII("Hello", readResult.Buffer); + reader.AdvanceTo(readResult.Buffer.End); + + readResult = await reader.ReadAsync(); + Assert.True(readResult.IsCompleted); + reader.AdvanceTo(readResult.Buffer.End); + + await body.StopAsync(); + } + } + + [Theory] + [InlineData(HttpVersion.Http10)] + [InlineData(HttpVersion.Http11)] + public async Task CanTryReadFromContentLengthPipeApis(HttpVersion httpVersion) + { + using (var input = new TestInput()) + { + var body = Http1MessageBody.For(httpVersion, new HttpRequestHeaders { HeaderContentLength = "5" }, input.Http1Connection); + var reader = new HttpRequestPipeReader(); + reader.StartAcceptingReads(body); + + input.Add("Hello"); + Assert.True(reader.TryRead(out var readResult)); + + Assert.Equal(5, readResult.Buffer.Length); + AssertASCII("Hello", readResult.Buffer); + reader.AdvanceTo(readResult.Buffer.End); + + reader.TryRead(out readResult); + Assert.True(readResult.IsCompleted); + reader.AdvanceTo(readResult.Buffer.End); + + await body.StopAsync(); + } + } + + [Theory] + [InlineData(HttpVersion.Http10)] + [InlineData(HttpVersion.Http11)] + public async Task ReadAsyncWithoutAdvanceFromContentLengthThrows(HttpVersion httpVersion) + { + using (var input = new TestInput()) + { + var body = Http1MessageBody.For(httpVersion, new HttpRequestHeaders { HeaderContentLength = "5" }, input.Http1Connection); + var reader = new HttpRequestPipeReader(); + reader.StartAcceptingReads(body); + + input.Add("Hello"); + var readResult = await reader.ReadAsync(); + + await Assert.ThrowsAsync(async () => await reader.ReadAsync()); + + await body.StopAsync(); + } + } + + [Theory] + [InlineData(HttpVersion.Http10)] + [InlineData(HttpVersion.Http11)] + public async Task TryReadWithoutAdvanceFromContentLengthThrows(HttpVersion httpVersion) + { + using (var input = new TestInput()) + { + var body = Http1MessageBody.For(httpVersion, new HttpRequestHeaders { HeaderContentLength = "5" }, input.Http1Connection); + var reader = new HttpRequestPipeReader(); + reader.StartAcceptingReads(body); + + input.Add("Hello"); + Assert.True(reader.TryRead(out var readResult)); + + Assert.Throws(() => reader.TryRead(out readResult)); + await body.StopAsync(); } } @@ -60,8 +149,10 @@ public async Task CanReadAsyncFromContentLength(HttpVersion httpVersion) using (var input = new TestInput()) { var body = Http1MessageBody.For(httpVersion, new HttpRequestHeaders { HeaderContentLength = "5" }, input.Http1Connection); - var stream = new HttpRequestStream(Mock.Of()); - stream.StartAcceptingReads(body); + + var reader = new HttpRequestPipeReader(); + var stream = new HttpRequestStream(Mock.Of(), reader); + reader.StartAcceptingReads(body); input.Add("Hello"); @@ -74,7 +165,6 @@ public async Task CanReadAsyncFromContentLength(HttpVersion httpVersion) count = await stream.ReadAsync(buffer, 0, buffer.Length); Assert.Equal(0, count); - input.Http1Connection.RequestBodyPipe.Reader.Complete(); await body.StopAsync(); } } @@ -87,8 +177,9 @@ public async Task CanReadFromChunkedEncoding() var body = Http1MessageBody.For(HttpVersion.Http11, new HttpRequestHeaders { HeaderTransferEncoding = "chunked" }, input.Http1Connection); var mockBodyControl = new Mock(); mockBodyControl.Setup(m => m.AllowSynchronousIO).Returns(true); - var stream = new HttpRequestStream(mockBodyControl.Object); - stream.StartAcceptingReads(body); + var reader = new HttpRequestPipeReader(); + var stream = new HttpRequestStream(mockBodyControl.Object, reader); + reader.StartAcceptingReads(body); input.Add("5\r\nHello\r\n"); @@ -103,7 +194,6 @@ public async Task CanReadFromChunkedEncoding() count = stream.Read(buffer, 0, buffer.Length); Assert.Equal(0, count); - input.Http1Connection.RequestBodyPipe.Reader.Complete(); await body.StopAsync(); } } @@ -114,8 +204,9 @@ public async Task CanReadAsyncFromChunkedEncoding() using (var input = new TestInput()) { var body = Http1MessageBody.For(HttpVersion.Http11, new HttpRequestHeaders { HeaderTransferEncoding = "chunked" }, input.Http1Connection); - var stream = new HttpRequestStream(Mock.Of()); - stream.StartAcceptingReads(body); + var reader = new HttpRequestPipeReader(); + var stream = new HttpRequestStream(Mock.Of(), reader); + reader.StartAcceptingReads(body); input.Add("5\r\nHello\r\n"); @@ -130,7 +221,6 @@ public async Task CanReadAsyncFromChunkedEncoding() count = await stream.ReadAsync(buffer, 0, buffer.Length); Assert.Equal(0, count); - input.Http1Connection.RequestBodyPipe.Reader.Complete(); await body.StopAsync(); } } @@ -141,8 +231,9 @@ public async Task ReadExitsGivenIncompleteChunkedExtension() using (var input = new TestInput()) { var body = Http1MessageBody.For(HttpVersion.Http11, new HttpRequestHeaders { HeaderTransferEncoding = "chunked" }, input.Http1Connection); - var stream = new HttpRequestStream(Mock.Of()); - stream.StartAcceptingReads(body); + var reader = new HttpRequestPipeReader(); + var stream = new HttpRequestStream(Mock.Of(), reader); + reader.StartAcceptingReads(body); input.Add("5;\r\0"); @@ -154,9 +245,16 @@ public async Task ReadExitsGivenIncompleteChunkedExtension() input.Add("\r\r\r\nHello\r\n0\r\n\r\n"); Assert.Equal(5, await readTask.DefaultTimeout()); - Assert.Equal(0, await stream.ReadAsync(buffer, 0, buffer.Length)); + try + { + var res = await stream.ReadAsync(buffer, 0, buffer.Length); + Assert.Equal(0, res); + } + catch (Exception ex) + { + throw ex; + } - input.Http1Connection.RequestBodyPipe.Reader.Complete(); await body.StopAsync(); } } @@ -167,8 +265,9 @@ public async Task ReadThrowsGivenChunkPrefixGreaterThanMaxInt() using (var input = new TestInput()) { var body = Http1MessageBody.For(HttpVersion.Http11, new HttpRequestHeaders { HeaderTransferEncoding = "chunked" }, input.Http1Connection); - var stream = new HttpRequestStream(Mock.Of()); - stream.StartAcceptingReads(body); + var reader = new HttpRequestPipeReader(); + var stream = new HttpRequestStream(Mock.Of(), reader); + reader.StartAcceptingReads(body); input.Add("80000000\r\n"); @@ -178,7 +277,6 @@ public async Task ReadThrowsGivenChunkPrefixGreaterThanMaxInt() Assert.IsType(ex.InnerException); Assert.Equal(CoreStrings.BadRequest_BadChunkSizeData, ex.Message); - input.Http1Connection.RequestBodyPipe.Reader.Complete(); await body.StopAsync(); } } @@ -189,8 +287,9 @@ public async Task ReadThrowsGivenChunkPrefixGreaterThan8Bytes() using (var input = new TestInput()) { var body = Http1MessageBody.For(HttpVersion.Http11, new HttpRequestHeaders { HeaderTransferEncoding = "chunked" }, input.Http1Connection); - var stream = new HttpRequestStream(Mock.Of()); - stream.StartAcceptingReads(body); + var reader = new HttpRequestPipeReader(); + var stream = new HttpRequestStream(Mock.Of(), reader); + reader.StartAcceptingReads(body); input.Add("012345678\r"); @@ -200,7 +299,6 @@ public async Task ReadThrowsGivenChunkPrefixGreaterThan8Bytes() Assert.Equal(CoreStrings.BadRequest_BadChunkSizeData, ex.Message); - input.Http1Connection.RequestBodyPipe.Reader.Complete(); await body.StopAsync(); } } @@ -215,8 +313,9 @@ public async Task CanReadFromRemainingData(HttpVersion httpVersion) var body = Http1MessageBody.For(httpVersion, new HttpRequestHeaders { HeaderConnection = "upgrade" }, input.Http1Connection); var mockBodyControl = new Mock(); mockBodyControl.Setup(m => m.AllowSynchronousIO).Returns(true); - var stream = new HttpRequestStream(mockBodyControl.Object); - stream.StartAcceptingReads(body); + var reader = new HttpRequestPipeReader(); + var stream = new HttpRequestStream(mockBodyControl.Object, reader); + reader.StartAcceptingReads(body); input.Add("Hello"); @@ -228,7 +327,6 @@ public async Task CanReadFromRemainingData(HttpVersion httpVersion) input.Fin(); - input.Http1Connection.RequestBodyPipe.Reader.Complete(); await body.StopAsync(); } } @@ -241,8 +339,9 @@ public async Task CanReadAsyncFromRemainingData(HttpVersion httpVersion) using (var input = new TestInput()) { var body = Http1MessageBody.For(httpVersion, new HttpRequestHeaders { HeaderConnection = "upgrade" }, input.Http1Connection); - var stream = new HttpRequestStream(Mock.Of()); - stream.StartAcceptingReads(body); + var reader = new HttpRequestPipeReader(); + var stream = new HttpRequestStream(Mock.Of(), reader); + reader.StartAcceptingReads(body); input.Add("Hello"); @@ -254,7 +353,6 @@ public async Task CanReadAsyncFromRemainingData(HttpVersion httpVersion) input.Fin(); - input.Http1Connection.RequestBodyPipe.Reader.Complete(); await body.StopAsync(); } } @@ -269,14 +367,16 @@ public async Task ReadFromNoContentLengthReturnsZero(HttpVersion httpVersion) var body = Http1MessageBody.For(httpVersion, new HttpRequestHeaders(), input.Http1Connection); var mockBodyControl = new Mock(); mockBodyControl.Setup(m => m.AllowSynchronousIO).Returns(true); - var stream = new HttpRequestStream(mockBodyControl.Object); - stream.StartAcceptingReads(body); + var reader = new HttpRequestPipeReader(); + var stream = new HttpRequestStream(mockBodyControl.Object, reader); + reader.StartAcceptingReads(body); input.Add("Hello"); var buffer = new byte[1024]; Assert.Equal(0, stream.Read(buffer, 0, buffer.Length)); + await body.StopAsync(); } } @@ -289,14 +389,16 @@ public async Task ReadAsyncFromNoContentLengthReturnsZero(HttpVersion httpVersio using (var input = new TestInput()) { var body = Http1MessageBody.For(httpVersion, new HttpRequestHeaders(), input.Http1Connection); - var stream = new HttpRequestStream(Mock.Of()); - stream.StartAcceptingReads(body); + var reader = new HttpRequestPipeReader(); + var stream = new HttpRequestStream(Mock.Of(), reader); + reader.StartAcceptingReads(body); input.Add("Hello"); var buffer = new byte[1024]; Assert.Equal(0, await stream.ReadAsync(buffer, 0, buffer.Length)); + await body.StopAsync(); } } @@ -307,8 +409,9 @@ public async Task CanHandleLargeBlocks() using (var input = new TestInput()) { var body = Http1MessageBody.For(HttpVersion.Http10, new HttpRequestHeaders { HeaderContentLength = "8197" }, input.Http1Connection); - var stream = new HttpRequestStream(Mock.Of()); - stream.StartAcceptingReads(body); + var reader = new HttpRequestPipeReader(); + var stream = new HttpRequestStream(Mock.Of(), reader); + reader.StartAcceptingReads(body); // Input needs to be greater than 4032 bytes to allocate a block not backed by a slab. var largeInput = new string('a', 8192); @@ -325,7 +428,7 @@ public async Task CanHandleLargeBlocks() Assert.Equal(8197, requestArray.Length); AssertASCII(largeInput + "Hello", new ArraySegment(requestArray, 0, requestArray.Length)); - input.Http1Connection.RequestBodyPipe.Reader.Complete(); + await body.StopAsync(); } } @@ -381,17 +484,19 @@ public async Task CopyToAsyncDoesNotCompletePipeReader() using (var input = new TestInput()) { var body = Http1MessageBody.For(HttpVersion.Http10, new HttpRequestHeaders { HeaderContentLength = "5" }, input.Http1Connection); + var reader = new HttpRequestPipeReader(); + var stream = new HttpRequestStream(Mock.Of(), reader); + reader.StartAcceptingReads(body); input.Add("Hello"); using (var ms = new MemoryStream()) { - await body.CopyToAsync(ms); + await stream.CopyToAsync(ms); } - Assert.Equal(0, await body.ReadAsync(new ArraySegment(new byte[1]))); + Assert.Equal(0, await stream.ReadAsync(new ArraySegment(new byte[1]))); - input.Http1Connection.RequestBodyPipe.Reader.Complete(); await body.StopAsync(); } } @@ -407,81 +512,26 @@ public async Task ConsumeAsyncConsumesAllRemainingInput() await body.ConsumeAsync(); - Assert.Equal(0, await body.ReadAsync(new ArraySegment(new byte[1]))); + Assert.True((await body.ReadAsync()).IsCompleted); - input.Http1Connection.RequestBodyPipe.Reader.Complete(); await body.StopAsync(); } } [Fact] - public async Task CopyToAsyncDoesNotCopyBlocks() + public async Task ConsumeAsyncAbortsConnectionInputAfterStartingTryReadWithoutAdvance() { - var writeCount = 0; - var writeTcs = new TaskCompletionSource<(byte[], int, int)>(TaskCreationOptions.RunContinuationsAsynchronously); - var mockDestination = new Mock { CallBase = true }; - - mockDestination - .Setup(m => m.WriteAsync(It.IsAny(), It.IsAny(), It.IsAny(), CancellationToken.None)) - .Callback((byte[] buffer, int offset, int count, CancellationToken cancellationToken) => - { - writeTcs.SetResult((buffer, offset, count)); - writeCount++; - }) - .Returns(Task.CompletedTask); - - using (var memoryPool = KestrelMemoryPool.Create()) + using (var input = new TestInput()) { - var options = new PipeOptions(pool: memoryPool, readerScheduler: PipeScheduler.Inline, writerScheduler: PipeScheduler.Inline, useSynchronizationContext: false); - var pair = DuplexPipe.CreateConnectionPair(options, options); - var transport = pair.Transport; - var http1ConnectionContext = new HttpConnectionContext - { - ServiceContext = new TestServiceContext(), - ConnectionFeatures = new FeatureCollection(), - Transport = transport, - MemoryPool = memoryPool, - TimeoutControl = Mock.Of() - }; - var http1Connection = new Http1Connection(http1ConnectionContext) - { - HasStartedConsumingRequestBody = true - }; - - var headers = new HttpRequestHeaders { HeaderContentLength = "12" }; - var body = Http1MessageBody.For(HttpVersion.Http11, headers, http1Connection); - - var copyToAsyncTask = body.CopyToAsync(mockDestination.Object); - - var bytes = Encoding.ASCII.GetBytes("Hello "); - var buffer = http1Connection.RequestBodyPipe.Writer.GetMemory(2048); - Assert.True(MemoryMarshal.TryGetArray(buffer, out ArraySegment segment)); - Buffer.BlockCopy(bytes, 0, segment.Array, segment.Offset, bytes.Length); - http1Connection.RequestBodyPipe.Writer.Advance(bytes.Length); - await http1Connection.RequestBodyPipe.Writer.FlushAsync(); - - // Verify the block passed to Stream.WriteAsync() is the same one incoming data was written into. - Assert.Equal((segment.Array, segment.Offset, bytes.Length), await writeTcs.Task); - - // Verify the again when GetMemory returns the tail space of the same block. - writeTcs = new TaskCompletionSource<(byte[], int, int)>(TaskCreationOptions.RunContinuationsAsynchronously); - bytes = Encoding.ASCII.GetBytes("World!"); - buffer = http1Connection.RequestBodyPipe.Writer.GetMemory(2048); - Assert.True(MemoryMarshal.TryGetArray(buffer, out segment)); - Buffer.BlockCopy(bytes, 0, segment.Array, segment.Offset, bytes.Length); - http1Connection.RequestBodyPipe.Writer.Advance(bytes.Length); - await http1Connection.RequestBodyPipe.Writer.FlushAsync(); - - Assert.Equal((segment.Array, segment.Offset, bytes.Length), await writeTcs.Task); + var body = Http1MessageBody.For(HttpVersion.Http10, new HttpRequestHeaders { HeaderContentLength = "5" }, input.Http1Connection); - http1Connection.RequestBodyPipe.Writer.Complete(); + input.Add("Hello"); - await copyToAsyncTask; + body.TryRead(out var readResult); - Assert.Equal(2, writeCount); + await body.ConsumeAsync(); - // Don't call body.StopAsync() because PumpAsync() was never called. - http1Connection.RequestBodyPipe.Reader.Complete(); + await body.StopAsync(); } } @@ -494,9 +544,11 @@ public async Task ConnectionUpgradeKeepAlive(string headerConnection) { using (var input = new TestInput()) { + // note the http1connection request body pipe reader should be the same. var body = Http1MessageBody.For(HttpVersion.Http11, new HttpRequestHeaders { HeaderConnection = headerConnection }, input.Http1Connection); - var stream = new HttpRequestStream(Mock.Of()); - stream.StartAcceptingReads(body); + var reader = new HttpRequestPipeReader(); + var stream = new HttpRequestStream(Mock.Of(), reader); + reader.StartAcceptingReads(body); input.Add("Hello"); @@ -506,7 +558,6 @@ public async Task ConnectionUpgradeKeepAlive(string headerConnection) input.Fin(); - input.Http1Connection.RequestBodyPipe.Reader.Complete(); await body.StopAsync(); } } @@ -522,8 +573,9 @@ public async Task UpgradeConnectionAcceptsContentLengthZero() using (var input = new TestInput()) { var body = Http1MessageBody.For(HttpVersion.Http11, new HttpRequestHeaders { HeaderConnection = headerConnection, ContentLength = 0 }, input.Http1Connection); - var stream = new HttpRequestStream(Mock.Of()); - stream.StartAcceptingReads(body); + var reader = new HttpRequestPipeReader(); + var stream = new HttpRequestStream(Mock.Of(), reader); + reader.StartAcceptingReads(body); input.Add("Hello"); @@ -533,7 +585,6 @@ public async Task UpgradeConnectionAcceptsContentLengthZero() input.Fin(); - input.Http1Connection.RequestBodyPipe.Reader.Complete(); await body.StopAsync(); } } @@ -544,8 +595,9 @@ public async Task PumpAsyncDoesNotReturnAfterCancelingInput() using (var input = new TestInput()) { var body = Http1MessageBody.For(HttpVersion.Http11, new HttpRequestHeaders { HeaderContentLength = "2" }, input.Http1Connection); - var stream = new HttpRequestStream(Mock.Of()); - stream.StartAcceptingReads(body); + var reader = new HttpRequestPipeReader(); + var stream = new HttpRequestStream(Mock.Of(), reader); + reader.StartAcceptingReads(body); // Add some input and consume it to ensure PumpAsync is running input.Add("a"); @@ -557,7 +609,6 @@ public async Task PumpAsyncDoesNotReturnAfterCancelingInput() input.Add("b"); Assert.Equal(1, await stream.ReadAsync(new byte[1], 0, 1)); - input.Http1Connection.RequestBodyPipe.Reader.Complete(); await body.StopAsync(); } } @@ -575,15 +626,16 @@ public async Task ReadAsyncThrowsOnTimeout() // Add some input and read it to start PumpAsync input.Add("a"); - Assert.Equal(1, await body.ReadAsync(new ArraySegment(new byte[1]))); + var readResult = await body.ReadAsync(); + Assert.Equal(1, readResult.Buffer.Length); + body.AdvanceTo(readResult.Buffer.End); // Time out on the next read input.Http1Connection.SendTimeoutResponse(); - var exception = await Assert.ThrowsAsync(async () => await body.ReadAsync(new Memory(new byte[1]))); + var exception = await Assert.ThrowsAsync(async () => await body.ReadAsync()); Assert.Equal(StatusCodes.Status408RequestTimeout, exception.StatusCode); - input.Http1Connection.RequestBodyPipe.Reader.Complete(); await body.StopAsync(); } } @@ -603,7 +655,11 @@ public async Task ConsumeAsyncCompletesAndDoesNotThrowOnTimeout() // Add some input and read it to start PumpAsync input.Add("a"); - Assert.Equal(1, await body.ReadAsync(new ArraySegment(new byte[1]))); + var readResult = await body.ReadAsync(); + Assert.Equal(1, readResult.Buffer.Length); + + // need to advance to make PipeReader in ReadCompleted state + body.AdvanceTo(readResult.Buffer.End); // Time out on the next read input.Http1Connection.SendTimeoutResponse(); @@ -614,7 +670,6 @@ public async Task ConsumeAsyncCompletesAndDoesNotThrowOnTimeout() It.IsAny(), It.Is(ex => ex.Reason == RequestRejectionReason.RequestBodyTimeout))); - input.Http1Connection.RequestBodyPipe.Reader.Complete(); await body.StopAsync(); } } @@ -629,21 +684,23 @@ public async Task CopyToAsyncThrowsOnTimeout() input.Http1ConnectionContext.TimeoutControl = mockTimeoutControl.Object; var body = Http1MessageBody.For(HttpVersion.Http11, new HttpRequestHeaders { HeaderContentLength = "5" }, input.Http1Connection); + var reader = new HttpRequestPipeReader(); + var stream = new HttpRequestStream(Mock.Of(), reader); + reader.StartAcceptingReads(body); // Add some input and read it to start PumpAsync input.Add("a"); - Assert.Equal(1, await body.ReadAsync(new ArraySegment(new byte[1]))); + Assert.Equal(1, (await body.ReadAsync()).Buffer.Length); // Time out on the next read input.Http1Connection.SendTimeoutResponse(); using (var ms = new MemoryStream()) { - var exception = await Assert.ThrowsAsync(() => body.CopyToAsync(ms)); + var exception = await Assert.ThrowsAsync(() => stream.CopyToAsync(ms)); Assert.Equal(StatusCodes.Status408RequestTimeout, exception.StatusCode); } - input.Http1Connection.RequestBodyPipe.Reader.Complete(); await body.StopAsync(); } } @@ -659,8 +716,9 @@ public async Task LogsWhenStartsReadingRequestBody() input.Http1Connection.TraceIdentifier = "RequestId"; var body = Http1MessageBody.For(HttpVersion.Http11, new HttpRequestHeaders { HeaderContentLength = "2" }, input.Http1Connection); - var stream = new HttpRequestStream(Mock.Of()); - stream.StartAcceptingReads(body); + var reader = new HttpRequestPipeReader(); + var stream = new HttpRequestStream(Mock.Of(), reader); + reader.StartAcceptingReads(body); // Add some input and consume it to ensure PumpAsync is running input.Add("a"); @@ -670,7 +728,6 @@ public async Task LogsWhenStartsReadingRequestBody() input.Fin(); - input.Http1Connection.RequestBodyPipe.Reader.Complete(); await body.StopAsync(); } } @@ -690,8 +747,9 @@ public async Task LogsWhenStopsReadingRequestBody() input.Http1Connection.TraceIdentifier = "RequestId"; var body = Http1MessageBody.For(HttpVersion.Http11, new HttpRequestHeaders { HeaderContentLength = "2" }, input.Http1Connection); - var stream = new HttpRequestStream(Mock.Of()); - stream.StartAcceptingReads(body); + var reader = new HttpRequestPipeReader(); + var stream = new HttpRequestStream(Mock.Of(), reader); + reader.StartAcceptingReads(body); // Add some input and consume it to ensure PumpAsync is running input.Add("a"); @@ -699,7 +757,6 @@ public async Task LogsWhenStopsReadingRequestBody() input.Fin(); - input.Http1Connection.RequestBodyPipe.Reader.Complete(); await body.StopAsync(); await logEvent.Task.DefaultTimeout(); @@ -717,13 +774,17 @@ public async Task PausesAndResumesRequestBodyTimeoutOnBackpressure() var body = Http1MessageBody.For(HttpVersion.Http11, new HttpRequestHeaders { HeaderContentLength = "12" }, input.Http1Connection); // Add some input and read it to start PumpAsync - var readTask1 = body.ReadAsync(new ArraySegment(new byte[6])); + var readTask1 = body.ReadAsync(); input.Add("hello,"); - Assert.Equal(6, await readTask1); + var readResult = await readTask1; + Assert.Equal(6, readResult.Buffer.Length); + body.AdvanceTo(readResult.Buffer.End); - var readTask2 = body.ReadAsync(new ArraySegment(new byte[6])); + var readTask2 = body.ReadAsync(); input.Add(" world"); - Assert.Equal(6, await readTask2); + readResult = await readTask2; + Assert.Equal(6, readResult.Buffer.Length); + body.AdvanceTo(readResult.Buffer.End); // Due to the limits set on HttpProtocol.RequestBodyPipe, backpressure should be triggered on every write to that pipe. mockTimeoutControl.Verify(timeoutControl => timeoutControl.StopTimingRead(), Times.Exactly(2)); @@ -751,14 +812,13 @@ public async Task OnlyEnforcesRequestBodyTimeoutAfterFirstRead() Assert.False(startRequestBodyCalled); // Add some input and read it to start PumpAsync - var readTask = body.ReadAsync(new ArraySegment(new byte[1])); + var readTask = body.ReadAsync(); Assert.True(startRequestBodyCalled); input.Add("a"); await readTask; - input.Http1Connection.RequestBodyPipe.Reader.Complete(); await body.StopAsync(); } } @@ -776,11 +836,15 @@ public async Task DoesNotEnforceRequestBodyTimeoutOnUpgradeRequests() // Add some input and read it to start PumpAsync input.Add("a"); - Assert.Equal(1, await body.ReadAsync(new ArraySegment(new byte[1]))); + var readResult = await body.ReadAsync(); + Assert.Equal(1, readResult.Buffer.Length); + + // need to advance to make PipeReader in ReadCompleted state + body.AdvanceTo(readResult.Buffer.End); input.Fin(); - Assert.Equal(0, await body.ReadAsync(new ArraySegment(new byte[1]))); + Assert.True((await body.ReadAsync()).IsCompleted); mockTimeoutControl.Verify(timeoutControl => timeoutControl.StartRequestBody(minReadRate), Times.Never); mockTimeoutControl.Verify(timeoutControl => timeoutControl.StopRequestBody(), Times.Never); @@ -791,7 +855,390 @@ public async Task DoesNotEnforceRequestBodyTimeoutOnUpgradeRequests() mockTimeoutControl.Verify(timeoutControl => timeoutControl.StopTimingRead(), Times.Never); mockTimeoutControl.Verify(timeoutControl => timeoutControl.StartTimingRead(), Times.Never); - input.Http1Connection.RequestBodyPipe.Reader.Complete(); + await body.StopAsync(); + } + } + + [Fact] + public async Task CancelPendingReadContentLengthWorks() + { + using (var input = new TestInput()) + { + var body = Http1MessageBody.For(HttpVersion.Http11, new HttpRequestHeaders { HeaderContentLength = "5" }, input.Http1Connection); + var reader = new HttpRequestPipeReader(); + reader.StartAcceptingReads(body); + + var readResultTask = reader.ReadAsync(); + + reader.CancelPendingRead(); + + var readResult = await readResultTask; + + Assert.True(readResult.IsCanceled); + + await body.StopAsync(); + } + } + + [Fact] + public async Task CancelPendingReadChunkedWorks() + { + using (var input = new TestInput()) + { + var body = Http1MessageBody.For(HttpVersion.Http11, new HttpRequestHeaders { HeaderTransferEncoding = "chunked" }, input.Http1Connection); + var reader = new HttpRequestPipeReader(); + reader.StartAcceptingReads(body); + + var readResultTask = reader.ReadAsync(); + + reader.CancelPendingRead(); + + var readResult = await readResultTask; + + Assert.True(readResult.IsCanceled); + + await body.StopAsync(); + } + } + + [Fact] + public async Task CancelPendingReadUpgradeWorks() + { + using (var input = new TestInput()) + { + var body = Http1MessageBody.For(HttpVersion.Http11, new HttpRequestHeaders { HeaderConnection = "upgrade" }, input.Http1Connection); + var reader = new HttpRequestPipeReader(); + reader.StartAcceptingReads(body); + + var readResultTask = reader.ReadAsync(); + + reader.CancelPendingRead(); + + var readResult = await readResultTask; + + Assert.True(readResult.IsCanceled); + + await body.StopAsync(); + } + } + + [Fact] + public async Task CancelPendingReadForZeroContentLengthCannotBeCanceled() + { + using (var input = new TestInput()) + { + var body = Http1MessageBody.For(HttpVersion.Http11, new HttpRequestHeaders(), input.Http1Connection); + var reader = new HttpRequestPipeReader(); + reader.StartAcceptingReads(body); + + var readResultTask = reader.ReadAsync(); + + Assert.True(readResultTask.IsCompleted); + + reader.CancelPendingRead(); + + var readResult = await readResultTask; + + Assert.False(readResult.IsCanceled); + + await body.StopAsync(); + } + } + + [Fact] + public async Task TryReadReturnsCompletedResultAfterReadingEntireContentLength() + { + using (var input = new TestInput()) + { + var body = Http1MessageBody.For(HttpVersion.Http11, new HttpRequestHeaders { HeaderContentLength = "5" }, input.Http1Connection); + var reader = new HttpRequestPipeReader(); + reader.StartAcceptingReads(body); + + input.Add("Hello"); + + Assert.True(reader.TryRead(out var readResult)); + + Assert.True(readResult.IsCompleted); + + await body.StopAsync(); + } + } + + [Fact] + public async Task TryReadReturnsCompletedResultAfterReadingEntireChunk() + { + using (var input = new TestInput()) + { + var body = Http1MessageBody.For(HttpVersion.Http11, new HttpRequestHeaders { HeaderTransferEncoding = "chunked" }, input.Http1Connection); + var reader = new HttpRequestPipeReader(); + reader.StartAcceptingReads(body); + + input.Add("5\r\nHello\r\n"); + + Assert.True(reader.TryRead(out var readResult)); + Assert.False(readResult.IsCompleted); + AssertASCII("Hello", readResult.Buffer); + + reader.AdvanceTo(readResult.Buffer.End); + + input.Add("0\r\n\r\n"); + Assert.True(reader.TryRead(out readResult)); + + Assert.True(readResult.IsCompleted); + reader.AdvanceTo(readResult.Buffer.End); + + await body.StopAsync(); + } + } + + [Fact] + public async Task TryReadDoesNotReturnCompletedReadResultFromUpgradeStreamUntilCompleted() + { + using (var input = new TestInput()) + { + var body = Http1MessageBody.For(HttpVersion.Http11, new HttpRequestHeaders { HeaderConnection = "upgrade" }, input.Http1Connection); + var reader = new HttpRequestPipeReader(); + reader.StartAcceptingReads(body); + + input.Add("Hello"); + + Assert.True(reader.TryRead(out var readResult)); + Assert.False(readResult.IsCompleted); + AssertASCII("Hello", readResult.Buffer); + + reader.AdvanceTo(readResult.Buffer.End); + + input.Fin(); + + reader.TryRead(out readResult); + Assert.True(readResult.IsCompleted); + + await body.StopAsync(); + } + } + + [Fact] + public async Task TryReadDoesReturnsCompletedReadResultForZeroContentLength() + { + using (var input = new TestInput()) + { + var body = Http1MessageBody.For(HttpVersion.Http11, new HttpRequestHeaders(), input.Http1Connection); + var reader = new HttpRequestPipeReader(); + reader.StartAcceptingReads(body); + + input.Add("Hello"); + + Assert.True(reader.TryRead(out var readResult)); + Assert.True(readResult.IsCompleted); + + reader.AdvanceTo(readResult.Buffer.End); + + reader.TryRead(out readResult); + Assert.True(readResult.IsCompleted); + + await body.StopAsync(); + } + } + + [Fact] // TODO + public async Task OnWriterCompletedForContentLengthDoesNotWork() + { + using (var input = new TestInput()) + { + var body = Http1MessageBody.For(HttpVersion.Http11, new HttpRequestHeaders { HeaderContentLength = "5" }, input.Http1Connection); + var reader = new HttpRequestPipeReader(); + reader.StartAcceptingReads(body); + + input.Add("Hello"); + var retVal = false; + + // Callback isn't fired at the moment. + reader.OnWriterCompleted((a, b) => retVal = true, null); + Assert.True(reader.TryRead(out var readResult)); + + Assert.True(readResult.IsCompleted); + Assert.False(retVal); + + await body.StopAsync(); + } + } + + [Fact] + public async Task OnWriterCompletedForChunkedWorks() + { + using (var input = new TestInput()) + { + var body = Http1MessageBody.For(HttpVersion.Http11, new HttpRequestHeaders { HeaderTransferEncoding = "chunked" }, input.Http1Connection); + var reader = new HttpRequestPipeReader(); + reader.StartAcceptingReads(body); + + var tcs = new TaskCompletionSource(); + reader.OnWriterCompleted((a, b) => tcs.SetResult(null), null); + + input.Add("0\r\n\r\n"); + + Assert.True(reader.TryRead(out var readResult)); + + Assert.True(readResult.IsCompleted); + Assert.Null(await tcs.Task.DefaultTimeout()); + + await body.StopAsync(); + } + } + + [Fact] + public async Task OnWriterCompletedForUpgradeWorks() + { + using (var input = new TestInput()) + { + var body = Http1MessageBody.For(HttpVersion.Http11, new HttpRequestHeaders { HeaderConnection = "upgrade" }, input.Http1Connection); + var reader = new HttpRequestPipeReader(); + reader.StartAcceptingReads(body); + + var retVal = false; + reader.OnWriterCompleted((a, b) => retVal = true, null); + + input.Add("hi"); + + Assert.True(reader.TryRead(out var readResult)); + reader.AdvanceTo(readResult.Buffer.End); + + input.Fin(); + + Assert.True(retVal); + + await body.StopAsync(); + } + } + + [Fact] + public async Task OnWriterCompletedForNoContentLengthNoop() + { + using (var input = new TestInput()) + { + var body = Http1MessageBody.For(HttpVersion.Http11, new HttpRequestHeaders(), input.Http1Connection); + var reader = new HttpRequestPipeReader(); + reader.StartAcceptingReads(body); + + var retVal = false; + reader.OnWriterCompleted((a, b) => retVal = true, null); + + input.Add("hi"); + + Assert.True(reader.TryRead(out var readResult)); + reader.AdvanceTo(readResult.Buffer.End); + + input.Fin(); + + Assert.False(retVal); + + await body.StopAsync(); + } + } + + [Fact] + public async Task CompleteForContentLengthDoesNotCompleteConnectionPipeMakesReadReturnThrow() + { + using (var input = new TestInput()) + { + var body = Http1MessageBody.For(HttpVersion.Http11, new HttpRequestHeaders { HeaderContentLength = "5" }, input.Http1Connection); + var reader = new HttpRequestPipeReader(); + reader.StartAcceptingReads(body); + + input.Add("a"); + + Assert.True(reader.TryRead(out var readResult)); + + Assert.False(readResult.IsCompleted); + + input.Add("asdf"); + + reader.Complete(); + reader.AdvanceTo(readResult.Buffer.End); + + Assert.Throws(() => reader.TryRead(out readResult)); + + await body.StopAsync(); + } + } + + [Fact] + public async Task CompleteForChunkedDoesNotCompleteConnectionPipeMakesReadThrow() + { + using (var input = new TestInput()) + { + var body = Http1MessageBody.For(HttpVersion.Http11, new HttpRequestHeaders { HeaderTransferEncoding = "chunked" }, input.Http1Connection); + var reader = new HttpRequestPipeReader(); + reader.StartAcceptingReads(body); + + input.Add("5\r\nHello\r\n"); + + Assert.True(reader.TryRead(out var readResult)); + + Assert.False(readResult.IsCompleted); + reader.AdvanceTo(readResult.Buffer.End); + + input.Add("1\r\nH\r\n"); + + reader.Complete(); + + Assert.Throws(() => reader.TryRead(out readResult)); + + await body.StopAsync(); + } + } + + [Fact] + public async Task CompleteForUpgradeDoesNotCompleteConnectionPipeMakesReadThrow() + { + using (var input = new TestInput()) + { + var body = Http1MessageBody.For(HttpVersion.Http11, new HttpRequestHeaders { HeaderConnection = "upgrade" }, input.Http1Connection); + var reader = new HttpRequestPipeReader(); + reader.StartAcceptingReads(body); + + input.Add("asdf"); + + Assert.True(reader.TryRead(out var readResult)); + + Assert.False(readResult.IsCompleted); + reader.AdvanceTo(readResult.Buffer.End); + + input.Add("asdf"); + + reader.Complete(); + + Assert.Throws(() => reader.TryRead(out readResult)); + + await body.StopAsync(); + } + } + + + [Fact] + public async Task CompleteForZeroByteBodyDoesNotCompleteConnectionPipeNoopsReads() + { + using (var input = new TestInput()) + { + var body = Http1MessageBody.For(HttpVersion.Http11, new HttpRequestHeaders(), input.Http1Connection); + var reader = new HttpRequestPipeReader(); + reader.StartAcceptingReads(body); + + input.Add("asdf"); + + Assert.True(reader.TryRead(out var readResult)); + + Assert.True(readResult.IsCompleted); + reader.AdvanceTo(readResult.Buffer.End); + + input.Add("asdf"); + + reader.Complete(); + + // TODO should this noop or throw? I think we should keep parity with normal pipe behavior. + // So maybe this should throw + reader.TryRead(out readResult); + await body.StopAsync(); } } @@ -807,6 +1254,18 @@ private void AssertASCII(string expected, ArraySegment actual) } } + private void AssertASCII(string expected, ReadOnlySequence actual) + { + var arr = actual.ToArray(); + var encoding = Encoding.ASCII; + var bytes = encoding.GetBytes(expected); + Assert.Equal(bytes.Length, actual.Length); + for (var index = 0; index < bytes.Length; index++) + { + Assert.Equal(bytes[index], arr[index]); + } + } + private class ThrowOnWriteSynchronousStream : Stream { public override void Flush() diff --git a/src/Servers/Kestrel/Core/test/StreamsTests.cs b/src/Servers/Kestrel/Core/test/StreamsTests.cs deleted file mode 100644 index a3c09848e338..000000000000 --- a/src/Servers/Kestrel/Core/test/StreamsTests.cs +++ /dev/null @@ -1,84 +0,0 @@ -// Copyright (c) .NET Foundation. All rights reserved. -// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. - -using System; -using System.Threading.Tasks; -using Microsoft.AspNetCore.Http.Features; -using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http; -using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure; -using Moq; -using Xunit; - -namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests -{ - public class StreamsTests - { - [Fact] - public async Task StreamsThrowAfterAbort() - { - var streams = new Streams(Mock.Of(), new HttpResponsePipeWriter(Mock.Of())); - var (request, response) = streams.Start(new MockMessageBody()); - - var ex = new Exception("My error"); - streams.Abort(ex); - - await response.WriteAsync(new byte[1], 0, 1); - Assert.Same(ex, - await Assert.ThrowsAsync(() => request.ReadAsync(new byte[1], 0, 1))); - } - - [Fact] - public async Task StreamsThrowOnAbortAfterUpgrade() - { - var streams = new Streams(Mock.Of(), new HttpResponsePipeWriter(Mock.Of())); - var (request, response) = streams.Start(new MockMessageBody(upgradeable: true)); - - var upgrade = streams.Upgrade(); - var ex = new Exception("My error"); - streams.Abort(ex); - - var writeEx = await Assert.ThrowsAsync(() => response.WriteAsync(new byte[1], 0, 1)); - Assert.Equal(CoreStrings.ResponseStreamWasUpgraded, writeEx.Message); - - Assert.Same(ex, - await Assert.ThrowsAsync(() => request.ReadAsync(new byte[1], 0, 1))); - - Assert.Same(ex, - await Assert.ThrowsAsync(() => upgrade.ReadAsync(new byte[1], 0, 1))); - - await upgrade.WriteAsync(new byte[1], 0, 1); - } - - [Fact] - public async Task StreamsThrowOnUpgradeAfterAbort() - { - var streams = new Streams(Mock.Of(), new HttpResponsePipeWriter(Mock.Of())); - - var (request, response) = streams.Start(new MockMessageBody(upgradeable: true)); - var ex = new Exception("My error"); - streams.Abort(ex); - - var upgrade = streams.Upgrade(); - - var writeEx = await Assert.ThrowsAsync(() => response.WriteAsync(new byte[1], 0, 1)); - Assert.Equal(CoreStrings.ResponseStreamWasUpgraded, writeEx.Message); - - Assert.Same(ex, - await Assert.ThrowsAsync(() => request.ReadAsync(new byte[1], 0, 1))); - - Assert.Same(ex, - await Assert.ThrowsAsync(() => upgrade.ReadAsync(new byte[1], 0, 1))); - - await upgrade.WriteAsync(new byte[1], 0, 1); - } - - private class MockMessageBody : MessageBody - { - public MockMessageBody(bool upgradeable = false) - : base(null, null) - { - RequestUpgrade = upgradeable; - } - } - } -} diff --git a/src/Servers/Kestrel/perf/Kestrel.Performance/Http1ReadingBenchmark.cs b/src/Servers/Kestrel/perf/Kestrel.Performance/Http1ReadingBenchmark.cs new file mode 100644 index 000000000000..122c5b694c19 --- /dev/null +++ b/src/Servers/Kestrel/perf/Kestrel.Performance/Http1ReadingBenchmark.cs @@ -0,0 +1,145 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Buffers; +using System.IO.Pipelines; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using BenchmarkDotNet.Attributes; +using Microsoft.AspNetCore.Http.Features; +using Microsoft.AspNetCore.Server.Kestrel.Core; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure; +using Microsoft.AspNetCore.Server.Kestrel.Transport.Abstractions.Internal; +using Microsoft.AspNetCore.Testing; + +namespace Microsoft.AspNetCore.Server.Kestrel.Performance +{ + public class Http1ReadingBenchmark + { + // Standard completed task + private static readonly Func _syncTaskFunc = (obj) => Task.CompletedTask; + // Non-standard completed task + private static readonly Task _pseudoAsyncTask = Task.FromResult(27); + private static readonly Func _pseudoAsyncTaskFunc = (obj) => _pseudoAsyncTask; + + private TestHttp1Connection _http1Connection; + private DuplexPipe.DuplexPipePair _pair; + private MemoryPool _memoryPool; + + private readonly byte[] _readData = Encoding.ASCII.GetBytes(new string('a', 100)); + + [GlobalSetup] + public void GlobalSetup() + { + _memoryPool = KestrelMemoryPool.Create(); + _http1Connection = MakeHttp1Connection(); + } + + [Params(true, false)] + public bool WithHeaders { get; set; } + + //[Params(true, false)] + //public bool Chunked { get; set; } + + [Params(Startup.None, Startup.Sync, Startup.Async)] + public Startup OnStarting { get; set; } + + [IterationSetup] + public void Setup() + { + _http1Connection.Reset(); + + _http1Connection.RequestHeaders.ContentLength = _readData.Length; + + if (!WithHeaders) + { + _http1Connection.FlushAsync().GetAwaiter().GetResult(); + } + + ResetState(); + } + + private void ResetState() + { + if (WithHeaders) + { + _http1Connection.ResetState(); + + switch (OnStarting) + { + case Startup.Sync: + _http1Connection.OnStarting(_syncTaskFunc, null); + break; + case Startup.Async: + _http1Connection.OnStarting(_pseudoAsyncTaskFunc, null); + break; + } + } + } + + [Benchmark] + public Task ReadAsync() + { + ResetState(); + + return _http1Connection.ResponseBody.ReadAsync(new byte[100], default(CancellationToken)).AsTask(); + } + + private TestHttp1Connection MakeHttp1Connection() + { + var options = new PipeOptions(_memoryPool, readerScheduler: PipeScheduler.Inline, writerScheduler: PipeScheduler.Inline, useSynchronizationContext: false); + var pair = DuplexPipe.CreateConnectionPair(options, options); + _pair = pair; + + var serviceContext = new ServiceContext + { + DateHeaderValueManager = new DateHeaderValueManager(), + ServerOptions = new KestrelServerOptions(), + Log = new MockTrace(), + HttpParser = new HttpParser() + }; + + var http1Connection = new TestHttp1Connection(new HttpConnectionContext + { + ServiceContext = serviceContext, + ConnectionFeatures = new FeatureCollection(), + MemoryPool = _memoryPool, + TimeoutControl = new TimeoutControl(timeoutHandler: null), + Transport = pair.Transport + }); + + http1Connection.Reset(); + http1Connection.InitializeBodyControl(new Http1ContentLengthMessageBody(keepAlive: true, 100, http1Connection)); + serviceContext.DateHeaderValueManager.OnHeartbeat(DateTimeOffset.UtcNow); + + return http1Connection; + } + + [IterationCleanup] + public void Cleanup() + { + var reader = _pair.Application.Input; + if (reader.TryRead(out var readResult)) + { + reader.AdvanceTo(readResult.Buffer.End); + } + } + + public enum Startup + { + None, + Sync, + Async + } + + [GlobalCleanup] + public void Dispose() + { + _memoryPool?.Dispose(); + } + } +} diff --git a/src/Servers/Kestrel/perf/Kestrel.Performance/Http1WritingBenchmark.cs b/src/Servers/Kestrel/perf/Kestrel.Performance/Http1WritingBenchmark.cs index f272a24b6104..d0c1cf3370f7 100644 --- a/src/Servers/Kestrel/perf/Kestrel.Performance/Http1WritingBenchmark.cs +++ b/src/Servers/Kestrel/perf/Kestrel.Performance/Http1WritingBenchmark.cs @@ -1,4 +1,4 @@ -// Copyright (c) .NET Foundation. All rights reserved. +// Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; @@ -119,7 +119,7 @@ private TestHttp1Connection MakeHttp1Connection() }); http1Connection.Reset(); - http1Connection.InitializeStreams(MessageBody.ZeroContentLengthKeepAlive); + http1Connection.InitializeBodyControl(MessageBody.ZeroContentLengthKeepAlive); serviceContext.DateHeaderValueManager.OnHeartbeat(DateTimeOffset.UtcNow); return http1Connection; diff --git a/src/Servers/Kestrel/test/FunctionalTests/RequestTests.cs b/src/Servers/Kestrel/test/FunctionalTests/RequestTests.cs index 46d722aec939..e2320b2d28c4 100644 --- a/src/Servers/Kestrel/test/FunctionalTests/RequestTests.cs +++ b/src/Servers/Kestrel/test/FunctionalTests/RequestTests.cs @@ -621,6 +621,7 @@ await connection.Send( [MemberData(nameof(ConnectionAdapterData))] public async Task RequestsCanBeAbortedMidRead(ListenOptions listenOptions) { + // This needs a timeout. const int applicationAbortedConnectionId = 34; var testContext = new TestServiceContext(LoggerFactory); diff --git a/src/Servers/Kestrel/test/InMemory.FunctionalTests/ChunkedRequestTests.cs b/src/Servers/Kestrel/test/InMemory.FunctionalTests/ChunkedRequestTests.cs index 03ba3d3c23a3..5d0be6b3e0cc 100644 --- a/src/Servers/Kestrel/test/InMemory.FunctionalTests/ChunkedRequestTests.cs +++ b/src/Servers/Kestrel/test/InMemory.FunctionalTests/ChunkedRequestTests.cs @@ -2,6 +2,7 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; +using System.Buffers; using System.Collections.Generic; using System.IO; using System.Linq; @@ -35,6 +36,24 @@ private async Task App(HttpContext httpContext) } } + private async Task PipeApp(HttpContext httpContext) + { + var request = httpContext.Request; + var response = httpContext.Response; + while (true) + { + var readResult = await request.BodyPipe.ReadAsync(); + if (readResult.IsCompleted) + { + break; + } + // Need to copy here. + await response.BodyPipe.WriteAsync(readResult.Buffer.ToArray()); + + request.BodyPipe.AdvanceTo(readResult.Buffer.End); + } + } + private async Task AppChunked(HttpContext httpContext) { var request = httpContext.Request; @@ -76,6 +95,35 @@ await connection.ReceiveEnd( } } + [Fact] + public async Task Http10TransferEncodingPipes() + { + var testContext = new TestServiceContext(LoggerFactory); + + using (var server = new TestServer(PipeApp, testContext)) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "POST / HTTP/1.0", + "Host:", + "Transfer-Encoding: chunked", + "", + "5", "Hello", + "6", " World", + "0", + "", + ""); + await connection.ReceiveEnd( + "HTTP/1.1 200 OK", + "Connection: close", + $"Date: {testContext.DateHeaderValue}", + "", + "Hello World"); + } + } + } + [Fact] public async Task Http10KeepAliveTransferEncoding() { @@ -261,6 +309,95 @@ public async Task TrailingHeadersAreParsed() } } + [Fact] + public async Task TrailingHeadersAreParsedWithPipe() + { + var requestCount = 10; + var requestsReceived = 0; + + using (var server = new TestServer(async httpContext => + { + var response = httpContext.Response; + var request = httpContext.Request; + + while (true) + { + var result = await request.BodyPipe.ReadAsync(); + request.BodyPipe.AdvanceTo(result.Buffer.End); + if (result.IsCompleted) + { + break; + } + } + + if (requestsReceived < requestCount) + { + Assert.Equal(new string('a', requestsReceived), request.Headers["X-Trailer-Header"].ToString()); + } + else + { + Assert.True(string.IsNullOrEmpty(request.Headers["X-Trailer-Header"])); + } + + requestsReceived++; + + response.Headers["Content-Length"] = new[] { "11" }; + + await response.Body.WriteAsync(Encoding.ASCII.GetBytes("Hello World"), 0, 11); + }, new TestServiceContext(LoggerFactory))) + { + var response = string.Join("\r\n", new string[] { + "HTTP/1.1 200 OK", + $"Date: {server.Context.DateHeaderValue}", + "Content-Length: 11", + "", + "Hello World"}); + + var expectedFullResponse = string.Join("", Enumerable.Repeat(response, requestCount + 1)); + + IEnumerable sendSequence = new string[] { + "POST / HTTP/1.1", + "Host:", + "Transfer-Encoding: chunked", + "", + "C", + "HelloChunked", + "0", + ""}; + + for (var i = 1; i < requestCount; i++) + { + sendSequence = sendSequence.Concat(new string[] { + "POST / HTTP/1.1", + "Host:", + "Transfer-Encoding: chunked", + "", + "C", + $"HelloChunk{i:00}", + "0", + string.Concat("X-Trailer-Header: ", new string('a', i)), + "" }); + } + + sendSequence = sendSequence.Concat(new string[] { + "POST / HTTP/1.1", + "Host:", + "Content-Length: 7", + "", + "Goodbye" + }); + + var fullRequest = sendSequence.ToArray(); + + using (var connection = server.CreateConnection()) + { + await connection.Send(fullRequest); + await connection.Receive(expectedFullResponse); + } + + await server.StopAsync(); + } + } [Fact] public async Task TrailingHeadersCountTowardsHeadersTotalSizeLimit() { @@ -677,6 +814,162 @@ await connection.SendAll( await server.StopAsync(); } } + + [Fact] + public async Task ChunkedRequestCallCancelPendingReadWorks() + { + var tcs = new TaskCompletionSource(); + var testContext = new TestServiceContext(LoggerFactory); + + using (var server = new TestServer(async httpContext => + { + var response = httpContext.Response; + var request = httpContext.Request; + + Assert.Equal("POST", request.Method); + + var readResult = await request.BodyPipe.ReadAsync(); + request.BodyPipe.AdvanceTo(readResult.Buffer.End); + + var requestTask = httpContext.Request.BodyPipe.ReadAsync(); + + httpContext.Request.BodyPipe.CancelPendingRead(); + + Assert.True((await requestTask).IsCanceled); + + tcs.SetResult(null); + + response.Headers["Content-Length"] = new[] { "11" }; + + await response.BodyPipe.WriteAsync(new Memory(Encoding.ASCII.GetBytes("Hello World"), 0, 11)); + + }, testContext)) + { + using (var connection = server.CreateConnection()) + { + await connection.SendAll( + "POST / HTTP/1.1", + "Host:", + "Transfer-Encoding: chunked", + "", + "1", + "H"); + await tcs.Task; + await connection.Send( + "4", + "ello", + "0", + "", + ""); + + await connection.Receive( + "HTTP/1.1 200 OK", + $"Date: {testContext.DateHeaderValue}", + "Content-Length: 11", + "", + "Hello World"); + } + await server.StopAsync(); + } + } + + [Fact] + public async Task ChunkedRequestCallCompleteThrowsExceptionOnRead() + { + var testContext = new TestServiceContext(LoggerFactory); + + using (var server = new TestServer(async httpContext => + { + var response = httpContext.Response; + var request = httpContext.Request; + + Assert.Equal("POST", request.Method); + + var readResult = await request.BodyPipe.ReadAsync(); + request.BodyPipe.AdvanceTo(readResult.Buffer.End); + + httpContext.Request.BodyPipe.Complete(); + + await Assert.ThrowsAsync(async () => await request.BodyPipe.ReadAsync()); + + response.Headers["Content-Length"] = new[] { "11" }; + + await response.BodyPipe.WriteAsync(new Memory(Encoding.ASCII.GetBytes("Hello World"), 0, 11)); + + }, testContext)) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "POST / HTTP/1.1", + "Host:", + "Transfer-Encoding: chunked", + "", + "1", + "H", + "4", + "ello", + "0", + "", + ""); + + await connection.Receive( + "HTTP/1.1 200 OK", + $"Date: {testContext.DateHeaderValue}", + "Content-Length: 11", + "", + "Hello World"); + } + await server.StopAsync(); + } + } + + [Fact] + public async Task ChunkedRequestCallCompleteWithExceptionCauses500() + { + var tcs = new TaskCompletionSource(); + var testContext = new TestServiceContext(LoggerFactory); + + using (var server = new TestServer(async httpContext => + { + var response = httpContext.Response; + var request = httpContext.Request; + + Assert.Equal("POST", request.Method); + + var readResult = await request.BodyPipe.ReadAsync(); + request.BodyPipe.AdvanceTo(readResult.Buffer.End); + + httpContext.Request.BodyPipe.Complete(new Exception()); + + response.Headers["Content-Length"] = new[] { "11" }; + + await response.BodyPipe.WriteAsync(new Memory(Encoding.ASCII.GetBytes("Hello World"), 0, 11)); + + }, testContext)) + { + using (var connection = server.CreateConnection()) + { + await connection.SendAll( + "POST / HTTP/1.1", + "Host:", + "Transfer-Encoding: chunked", + "", + "1", + "H", + "0", + "", + ""); + + await connection.Receive( + "HTTP/1.1 500 Internal Server Error", + $"Date: {testContext.DateHeaderValue}", + "Content-Length: 0", + "", + ""); + } + await server.StopAsync(); + } + } } } - diff --git a/src/Servers/Kestrel/test/InMemory.FunctionalTests/Http2/Http2StreamTests.cs b/src/Servers/Kestrel/test/InMemory.FunctionalTests/Http2/Http2StreamTests.cs index a35fdb388b5f..a427f7c4804c 100644 --- a/src/Servers/Kestrel/test/InMemory.FunctionalTests/Http2/Http2StreamTests.cs +++ b/src/Servers/Kestrel/test/InMemory.FunctionalTests/Http2/Http2StreamTests.cs @@ -727,6 +727,110 @@ await ExpectAsync(Http2FrameType.DATA, Assert.Equal("0", _decodedHeaders[HeaderNames.ContentLength]); } + [Fact] + public async Task ContentLength_Received_MultipleDataFrame_ReadViaPipe_Verified() + { + var headers = new[] + { + new KeyValuePair(HeaderNames.Method, "POST"), + new KeyValuePair(HeaderNames.Path, "/"), + new KeyValuePair(HeaderNames.Scheme, "http"), + new KeyValuePair(HeaderNames.ContentLength, "12"), + }; + await InitializeConnectionAsync(async context => + { + var readResult = await context.Request.BodyPipe.ReadAsync(); + while (!readResult.IsCompleted) + { + context.Request.BodyPipe.AdvanceTo(readResult.Buffer.Start, readResult.Buffer.End); + readResult = await context.Request.BodyPipe.ReadAsync(); + } + + Assert.Equal(12, readResult.Buffer.Length); + context.Request.BodyPipe.AdvanceTo(readResult.Buffer.End); + }); + + await StartStreamAsync(1, headers, endStream: false); + await SendDataAsync(1, new byte[1], endStream: false); + await SendDataAsync(1, new byte[3], endStream: false); + await SendDataAsync(1, new byte[8], endStream: true); + + var headersFrame = await ExpectAsync(Http2FrameType.HEADERS, + withLength: 55, + withFlags: (byte)Http2HeadersFrameFlags.END_HEADERS, + withStreamId: 1); + await ExpectAsync(Http2FrameType.DATA, + withLength: 0, + withFlags: (byte)Http2DataFrameFlags.END_STREAM, + withStreamId: 1); + + await StopConnectionAsync(expectedLastStreamId: 1, ignoreNonGoAwayFrames: false); + + _hpackDecoder.Decode(headersFrame.PayloadSequence, endHeaders: false, handler: this); + + Assert.Equal(3, _decodedHeaders.Count); + Assert.Contains("date", _decodedHeaders.Keys, StringComparer.OrdinalIgnoreCase); + Assert.Equal("200", _decodedHeaders[HeaderNames.Status]); + Assert.Equal("0", _decodedHeaders[HeaderNames.ContentLength]); + } + + [Fact] + public async Task ContentLength_Received_MultipleDataFrame_ReadViaPipeAndStream_Verified() + { + var tcs = new TaskCompletionSource(); + var headers = new[] + { + new KeyValuePair(HeaderNames.Method, "POST"), + new KeyValuePair(HeaderNames.Path, "/"), + new KeyValuePair(HeaderNames.Scheme, "http"), + new KeyValuePair(HeaderNames.ContentLength, "12"), + }; + await InitializeConnectionAsync(async context => + { + var readResult = await context.Request.BodyPipe.ReadAsync(); + Assert.Equal(1, readResult.Buffer.Length); + context.Request.BodyPipe.AdvanceTo(readResult.Buffer.End); + + tcs.SetResult(null); + + var buffer = new byte[100]; + + var read = await context.Request.Body.ReadAsync(buffer, 0, buffer.Length); + var total = read; + while (read > 0) + { + read = await context.Request.Body.ReadAsync(buffer, total, buffer.Length - total); + total += read; + } + + Assert.Equal(11, total); + }); + + await StartStreamAsync(1, headers, endStream: false); + await SendDataAsync(1, new byte[1], endStream: false); + await tcs.Task; + await SendDataAsync(1, new byte[3], endStream: false); + await SendDataAsync(1, new byte[8], endStream: true); + + var headersFrame = await ExpectAsync(Http2FrameType.HEADERS, + withLength: 55, + withFlags: (byte)Http2HeadersFrameFlags.END_HEADERS, + withStreamId: 1); + await ExpectAsync(Http2FrameType.DATA, + withLength: 0, + withFlags: (byte)Http2DataFrameFlags.END_STREAM, + withStreamId: 1); + + await StopConnectionAsync(expectedLastStreamId: 1, ignoreNonGoAwayFrames: false); + + _hpackDecoder.Decode(headersFrame.PayloadSequence, endHeaders: false, handler: this); + + Assert.Equal(3, _decodedHeaders.Count); + Assert.Contains("date", _decodedHeaders.Keys, StringComparer.OrdinalIgnoreCase); + Assert.Equal("200", _decodedHeaders[HeaderNames.Status]); + Assert.Equal("0", _decodedHeaders[HeaderNames.ContentLength]); + } + [Fact] public async Task ContentLength_Received_NoDataFrames_Reset() { @@ -911,6 +1015,53 @@ await InitializeConnectionAsync(async context => Assert.IsType(thrownEx.InnerException); } + [Fact] + public async Task ContentLength_Received_ReadViaPipes() + { + await InitializeConnectionAsync(async context => + { + var readResult = await context.Request.BodyPipe.ReadAsync(); + Assert.Equal(12, readResult.Buffer.Length); + Assert.True(readResult.IsCompleted); + context.Request.BodyPipe.AdvanceTo(readResult.Buffer.End); + + readResult = await context.Request.BodyPipe.ReadAsync(); + Assert.True(readResult.IsCompleted); + }); + + var headers = new[] + { + new KeyValuePair(HeaderNames.Method, "POST"), + new KeyValuePair(HeaderNames.Path, "/"), + new KeyValuePair(HeaderNames.Scheme, "http"), + new KeyValuePair("a", _4kHeaderValue), + new KeyValuePair("b", _4kHeaderValue), + new KeyValuePair("c", _4kHeaderValue), + new KeyValuePair("d", _4kHeaderValue), + new KeyValuePair(HeaderNames.ContentLength, "12"), + }; + await StartStreamAsync(1, headers, endStream: false); + await SendDataAsync(1, new byte[12], endStream: true); + + var headersFrame = await ExpectAsync(Http2FrameType.HEADERS, + withLength: 55, + withFlags: (byte)Http2HeadersFrameFlags.END_HEADERS, + withStreamId: 1); + await ExpectAsync(Http2FrameType.DATA, + withLength: 0, + withFlags: (byte)Http2DataFrameFlags.END_STREAM, + withStreamId: 1); + + await StopConnectionAsync(expectedLastStreamId: 1, ignoreNonGoAwayFrames: false); + + _hpackDecoder.Decode(headersFrame.PayloadSequence, endHeaders: false, handler: this); + + Assert.Equal(3, _decodedHeaders.Count); + Assert.Contains("date", _decodedHeaders.Keys, StringComparer.OrdinalIgnoreCase); + Assert.Equal("200", _decodedHeaders[HeaderNames.Status]); + Assert.Equal("0", _decodedHeaders[HeaderNames.ContentLength]); + } + [Fact] // TODO https://github.com/aspnet/AspNetCore/issues/7034 public async Task ContentLength_Response_FirstWriteMoreBytesWritten_Throws_Sends500() { diff --git a/src/Servers/Kestrel/test/InMemory.FunctionalTests/RequestTests.cs b/src/Servers/Kestrel/test/InMemory.FunctionalTests/RequestTests.cs index c482c63b4f0e..a28fb1cabdae 100644 --- a/src/Servers/Kestrel/test/InMemory.FunctionalTests/RequestTests.cs +++ b/src/Servers/Kestrel/test/InMemory.FunctionalTests/RequestTests.cs @@ -18,7 +18,6 @@ using Microsoft.AspNetCore.Server.Kestrel.InMemory.FunctionalTests.TestTransport; using Microsoft.AspNetCore.Testing; using Microsoft.Extensions.Logging.Testing; -using Moq; using Xunit; namespace Microsoft.AspNetCore.Server.Kestrel.InMemory.FunctionalTests @@ -612,9 +611,6 @@ public async Task ZeroContentLengthAssumedOnNonKeepAliveRequestsWithoutContentLe { using (var connection = server.CreateConnection()) { - // Use Send instead of SendEnd to ensure the connection will remain open while - // the app runs and reads 0 bytes from the body nonetheless. This checks that - // https://github.com/aspnet/KestrelHttpServer/issues/1104 is not regressing. await connection.Send( "GET / HTTP/1.1", "Host:", @@ -650,11 +646,95 @@ await connection.ReceiveEnd( } [Fact] - public async Task ConnectionClosesWhenFinReceivedBeforeRequestCompletes() + public async Task ZeroContentLengthAssumedOnNonKeepAliveRequestsWithoutContentLengthOrTransferEncodingHeaderPipeReader() + { + var testContext = new TestServiceContext(LoggerFactory); + + using (var server = new TestServer(async httpContext => + { + var readResult = await httpContext.Request.BodyPipe.ReadAsync().AsTask().DefaultTimeout(); + // This will hang if 0 content length is not assumed by the server + Assert.True(readResult.IsCompleted); + }, testContext)) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "GET / HTTP/1.1", + "Host:", + "Connection: close", + "", + ""); + await connection.ReceiveEnd( + "HTTP/1.1 200 OK", + "Connection: close", + $"Date: {testContext.DateHeaderValue}", + "Content-Length: 0", + "", + ""); + } + + using (var connection = server.CreateConnection()) + { + await connection.Send( + "GET / HTTP/1.0", + "Host:", + "", + ""); + await connection.ReceiveEnd( + "HTTP/1.1 200 OK", + "Connection: close", + $"Date: {testContext.DateHeaderValue}", + "Content-Length: 0", + "", + ""); + } + await server.StopAsync(); + } + } + + [Fact] + public async Task ContentLengthReadAsyncPipeReader() { var testContext = new TestServiceContext(LoggerFactory); - // FIN callbacks are scheduled so run inline to make this test more reliable - testContext.Scheduler = PipeScheduler.Inline; + + using (var server = new TestServer(async httpContext => + { + var readResult = await httpContext.Request.BodyPipe.ReadAsync(); + // This will hang if 0 content length is not assumed by the server + Assert.Equal(5, readResult.Buffer.Length); + httpContext.Request.BodyPipe.AdvanceTo(readResult.Buffer.End); + }, testContext)) + { + using (var connection = server.CreateConnection()) + { + await connection.SendAll( + "POST / HTTP/1.0", + "Host:", + "Content-Length: 5", + "", + "hello"); + await connection.ReceiveEnd( + "HTTP/1.1 200 OK", + "Connection: close", + $"Date: {testContext.DateHeaderValue}", + "Content-Length: 0", + "", + ""); + } + + await server.StopAsync(); + } + } + + [Fact] + public async Task ConnectionClosesWhenFinReceivedBeforeRequestCompletes() + { + var testContext = new TestServiceContext(LoggerFactory) + { + // FIN callbacks are scheduled so run inline to make this test more reliable + Scheduler = PipeScheduler.Inline + }; using (var server = new TestServer(TestApp.EchoAppChunked, testContext)) { @@ -1258,6 +1338,147 @@ await connection.Receive( } } + [Fact] + public async Task ContentLengthRequestCallCancelPendingReadWorks() + { + var tcs = new TaskCompletionSource(); + var testContext = new TestServiceContext(LoggerFactory); + + using (var server = new TestServer(async httpContext => + { + var response = httpContext.Response; + var request = httpContext.Request; + + Assert.Equal("POST", request.Method); + + var readResult = await request.BodyPipe.ReadAsync(); + request.BodyPipe.AdvanceTo(readResult.Buffer.End); + + var requestTask = httpContext.Request.BodyPipe.ReadAsync(); + + httpContext.Request.BodyPipe.CancelPendingRead(); + + Assert.True((await requestTask).IsCanceled); + + tcs.SetResult(null); + + response.Headers["Content-Length"] = new[] { "11" }; + + await response.BodyPipe.WriteAsync(new Memory(Encoding.ASCII.GetBytes("Hello World"), 0, 11)); + + }, testContext)) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "POST / HTTP/1.1", + "Host:", + "Content-Length: 5", + "", + "H"); + await tcs.Task; + await connection.Send( + "ello"); + await connection.Receive( + "HTTP/1.1 200 OK", + $"Date: {testContext.DateHeaderValue}", + "Content-Length: 11", + "", + "Hello World"); + } + await server.StopAsync(); + } + } + + [Fact] + public async Task ContentLengthRequestCallCompleteThrowsExceptionOnRead() + { + var testContext = new TestServiceContext(LoggerFactory); + + using (var server = new TestServer(async httpContext => + { + var response = httpContext.Response; + var request = httpContext.Request; + + Assert.Equal("POST", request.Method); + + var readResult = await request.BodyPipe.ReadAsync(); + request.BodyPipe.AdvanceTo(readResult.Buffer.End); + + httpContext.Request.BodyPipe.Complete(); + + await Assert.ThrowsAsync(async () => await request.BodyPipe.ReadAsync()); + + response.Headers["Content-Length"] = new[] { "11" }; + + await response.BodyPipe.WriteAsync(new Memory(Encoding.ASCII.GetBytes("Hello World"), 0, 11)); + + }, testContext)) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "POST / HTTP/1.1", + "Host:", + "Content-Length: 5", + "", + "Hello"); + + await connection.Receive( + "HTTP/1.1 200 OK", + $"Date: {testContext.DateHeaderValue}", + "Content-Length: 11", + "", + "Hello World"); + } + await server.StopAsync(); + } + } + + [Fact] + public async Task ContentLengthCallCompleteWithExceptionCauses500() + { + var tcs = new TaskCompletionSource(); + var testContext = new TestServiceContext(LoggerFactory); + + using (var server = new TestServer(async httpContext => + { + var response = httpContext.Response; + var request = httpContext.Request; + + Assert.Equal("POST", request.Method); + + var readResult = await request.BodyPipe.ReadAsync(); + request.BodyPipe.AdvanceTo(readResult.Buffer.End); + + httpContext.Request.BodyPipe.Complete(new Exception()); + + response.Headers["Content-Length"] = new[] { "11" }; + + await response.BodyPipe.WriteAsync(new Memory(Encoding.ASCII.GetBytes("Hello World"), 0, 11)); + + }, testContext)) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "POST / HTTP/1.1", + "Host:", + "Content-Length: 5", + "", + "Hello"); + + await connection.Receive( + "HTTP/1.1 500 Internal Server Error", + $"Date: {testContext.DateHeaderValue}", + "Content-Length: 0", + "", + ""); + } + await server.StopAsync(); + } + } + public static TheoryData HostHeaderData => HttpParsingData.HostHeaderData; } } diff --git a/src/Servers/Kestrel/tools/CodeGenerator/HttpProtocolFeatureCollection.cs b/src/Servers/Kestrel/tools/CodeGenerator/HttpProtocolFeatureCollection.cs index 4e7b1a689fb5..f61df14ff2c9 100644 --- a/src/Servers/Kestrel/tools/CodeGenerator/HttpProtocolFeatureCollection.cs +++ b/src/Servers/Kestrel/tools/CodeGenerator/HttpProtocolFeatureCollection.cs @@ -14,6 +14,7 @@ public static string GenerateFile() "IHttpRequestFeature", "IHttpResponseFeature", "IResponseBodyPipeFeature", + "IRequestBodyPipeFeature", "IHttpRequestIdentifierFeature", "IServiceProvidersFeature", "IHttpRequestLifetimeFeature", @@ -62,6 +63,7 @@ public static string GenerateFile() "IHttpRequestFeature", "IHttpResponseFeature", "IResponseBodyPipeFeature", + "IRequestBodyPipeFeature", "IHttpUpgradeFeature", "IHttpRequestIdentifierFeature", "IHttpRequestLifetimeFeature",