From 94179054fec5e6d295ec9685a372911571c6aa71 Mon Sep 17 00:00:00 2001 From: Stephen Toub Date: Fri, 11 Feb 2022 21:05:33 -0500 Subject: [PATCH 1/2] Use static abstract interface methods in Ssl/NegotiateStream adapters --- .../System/Net/Security/NegotiateStream.cs | 70 +++++++------ .../System/Net/Security/ReadWriteAdapter.cs | 63 +++++------- .../Net/Security/SslStream.Implementation.cs | 97 +++++++++---------- .../src/System/Net/Security/SslStream.cs | 17 ++-- .../src/System/Net/StreamFramer.cs | 15 +-- 5 files changed, 124 insertions(+), 138 deletions(-) diff --git a/src/libraries/System.Net.Security/src/System/Net/Security/NegotiateStream.cs b/src/libraries/System.Net.Security/src/System/Net/Security/NegotiateStream.cs index 03be81907a3bc5..ba2cc708bab886 100644 --- a/src/libraries/System.Net.Security/src/System/Net/Security/NegotiateStream.cs +++ b/src/libraries/System.Net.Security/src/System/Net/Security/NegotiateStream.cs @@ -134,7 +134,7 @@ public virtual void AuthenticateAsServer(NetworkCredential credential, Protectio public virtual void AuthenticateAsServer(NetworkCredential credential, ExtendedProtectionPolicy? policy, ProtectionLevel requiredProtectionLevel, TokenImpersonationLevel requiredImpersonationLevel) { ValidateCreateContext(DefaultPackage, credential, string.Empty, policy, requiredProtectionLevel, requiredImpersonationLevel); - AuthenticateAsync(new SyncReadWriteAdapter(InnerStream)).GetAwaiter().GetResult(); + AuthenticateAsync(default(CancellationToken)).GetAwaiter().GetResult(); } public virtual IAsyncResult BeginAuthenticateAsServer(AsyncCallback? asyncCallback, object? asyncState) => @@ -172,7 +172,7 @@ public virtual void AuthenticateAsClient( NetworkCredential credential, ChannelBinding? binding, string targetName, ProtectionLevel requiredProtectionLevel, TokenImpersonationLevel allowedImpersonationLevel) { ValidateCreateContext(DefaultPackage, isServer: false, credential, targetName, binding, requiredProtectionLevel, allowedImpersonationLevel); - AuthenticateAsync(new SyncReadWriteAdapter(InnerStream)).GetAwaiter().GetResult(); + AuthenticateAsync(default(CancellationToken)).GetAwaiter().GetResult(); } public virtual Task AuthenticateAsClientAsync() => @@ -195,7 +195,7 @@ public virtual Task AuthenticateAsClientAsync( TokenImpersonationLevel allowedImpersonationLevel) { ValidateCreateContext(DefaultPackage, isServer: false, credential, targetName, binding, requiredProtectionLevel, allowedImpersonationLevel); - return AuthenticateAsync(new AsyncReadWriteAdapter(InnerStream, cancellationToken: default)); + return AuthenticateAsync(default(CancellationToken)); } public virtual Task AuthenticateAsServerAsync() => @@ -211,7 +211,7 @@ public virtual Task AuthenticateAsServerAsync( NetworkCredential credential, ExtendedProtectionPolicy? policy, ProtectionLevel requiredProtectionLevel, TokenImpersonationLevel requiredImpersonationLevel) { ValidateCreateContext(DefaultPackage, credential, string.Empty, policy, requiredProtectionLevel, requiredImpersonationLevel); - return AuthenticateAsync(new AsyncReadWriteAdapter(InnerStream, cancellationToken: default)); + return AuthenticateAsync(default(CancellationToken)); } public override bool IsAuthenticated => IsAuthenticatedCore; @@ -312,7 +312,7 @@ public override int Read(byte[] buffer, int offset, int count) return InnerStream.Read(buffer, offset, count); } - ValueTask vt = ReadAsync(new SyncReadWriteAdapter(InnerStream), new Memory(buffer, offset, count)); + ValueTask vt = ReadAsync(new Memory(buffer, offset, count), default(CancellationToken)); Debug.Assert(vt.IsCompleted, "Should have completed synchroously with sync adapter"); return vt.GetAwaiter().GetResult(); } @@ -327,7 +327,7 @@ public override Task ReadAsync(byte[] buffer, int offset, int count, Cancel return InnerStream.ReadAsync(buffer, offset, count, cancellationToken); } - return ReadAsync(new AsyncReadWriteAdapter(InnerStream, cancellationToken), new Memory(buffer, offset, count)).AsTask(); + return ReadAsync(new Memory(buffer, offset, count), cancellationToken).AsTask(); } public override ValueTask ReadAsync(Memory buffer, CancellationToken cancellationToken = default) @@ -338,10 +338,11 @@ public override ValueTask ReadAsync(Memory buffer, CancellationToken return InnerStream.ReadAsync(buffer, cancellationToken); } - return ReadAsync(new AsyncReadWriteAdapter(InnerStream, cancellationToken), buffer); + return ReadAsync(buffer, cancellationToken); } - private async ValueTask ReadAsync(TAdapter adapter, Memory buffer, [CallerMemberName] string? callerName = null) where TAdapter : IReadWriteAdapter + private async ValueTask ReadAsync(Memory buffer, CancellationToken cancellationToken, [CallerMemberName] string? callerName = null) + where TIOAdapter : IReadWriteAdapter { if (Interlocked.Exchange(ref _readInProgress, 1) == 1) { @@ -364,7 +365,7 @@ private async ValueTask ReadAsync(TAdapter adapter, Memory while (true) { - int readBytes = await ReadAllAsync(adapter, _readHeader, allowZeroRead: true).ConfigureAwait(false); + int readBytes = await ReadAllAsync(InnerStream, _readHeader, allowZeroRead: true, cancellationToken).ConfigureAwait(false); if (readBytes == 0) { return 0; @@ -389,7 +390,7 @@ private async ValueTask ReadAsync(TAdapter adapter, Memory _readBuffer = new byte[readBytes]; } - readBytes = await ReadAllAsync(adapter, new Memory(_readBuffer, 0, readBytes), allowZeroRead: false).ConfigureAwait(false); + readBytes = await ReadAllAsync(InnerStream, new Memory(_readBuffer, 0, readBytes), allowZeroRead: false, cancellationToken).ConfigureAwait(false); // Decrypt into internal buffer, change "readBytes" to count now _Decrypted Bytes_ // Decrypted data start from zero offset, the size can be shrunk after decryption. @@ -421,13 +422,13 @@ private async ValueTask ReadAsync(TAdapter adapter, Memory _readInProgress = 0; } - static async ValueTask ReadAllAsync(TAdapter adapter, Memory buffer, bool allowZeroRead) + static async ValueTask ReadAllAsync(Stream stream, Memory buffer, bool allowZeroRead, CancellationToken cancellationToken) { int read = 0; do { - int bytes = await adapter.ReadAsync(buffer).ConfigureAwait(false); + int bytes = await TIOAdapter.ReadAsync(stream, buffer, cancellationToken).ConfigureAwait(false); if (bytes == 0) { if (read != 0 || !allowZeroRead) @@ -457,7 +458,7 @@ public override void Write(byte[] buffer, int offset, int count) return; } - WriteAsync(new SyncReadWriteAdapter(InnerStream), new ReadOnlyMemory(buffer, offset, count)).GetAwaiter().GetResult(); + WriteAsync(new ReadOnlyMemory(buffer, offset, count), default(CancellationToken)).GetAwaiter().GetResult(); } /// A that represents the asynchronous read operation. @@ -471,7 +472,7 @@ public override Task WriteAsync(byte[] buffer, int offset, int count, Cancellati return InnerStream.WriteAsync(buffer, offset, count, cancellationToken); } - return WriteAsync(new AsyncReadWriteAdapter(InnerStream, cancellationToken), new ReadOnlyMemory(buffer, offset, count)); + return WriteAsync(new ReadOnlyMemory(buffer, offset, count), cancellationToken); } /// A that represents the asynchronous read operation. @@ -483,10 +484,11 @@ public override ValueTask WriteAsync(ReadOnlyMemory buffer, CancellationTo return InnerStream.WriteAsync(buffer, cancellationToken); } - return new ValueTask(WriteAsync(new AsyncReadWriteAdapter(InnerStream, cancellationToken), buffer)); + return new ValueTask(WriteAsync(buffer, cancellationToken)); } - private async Task WriteAsync(TAdapter adapter, ReadOnlyMemory buffer) where TAdapter : IReadWriteAdapter + private async Task WriteAsync(ReadOnlyMemory buffer, CancellationToken cancellationToken) + where TIOAdapter : IReadWriteAdapter { if (Interlocked.Exchange(ref _writeInProgress, 1) == 1) { @@ -508,7 +510,7 @@ private async Task WriteAsync(TAdapter adapter, ReadOnlyMemory b throw new IOException(SR.net_io_encrypt, e); } - await adapter.WriteAsync(_writeBuffer, 0, encryptedBytes).ConfigureAwait(false); + await TIOAdapter.WriteAsync(InnerStream, _writeBuffer, 0, encryptedBytes, cancellationToken).ConfigureAwait(false); buffer = buffer.Slice(chunkBytes); } } @@ -702,7 +704,8 @@ private void ThrowIfFailed(bool authSuccessCheck) } } - private async Task AuthenticateAsync(TAdapter adapter, [CallerMemberName] string? callerName = null) where TAdapter : IReadWriteAdapter + private async Task AuthenticateAsync(CancellationToken cancellationToken, [CallerMemberName] string? callerName = null) + where TIOAdapter : IReadWriteAdapter { Debug.Assert(_context != null); @@ -715,8 +718,8 @@ private async Task AuthenticateAsync(TAdapter adapter, [CallerMemberNa try { await (_context.IsServer ? - ReceiveBlobAsync(adapter) : // server should listen for a client blob - SendBlobAsync(adapter, message: null)).ConfigureAwait(false); // client should send the first blob + ReceiveBlobAsync(cancellationToken) : // server should listen for a client blob + SendBlobAsync(message: null, cancellationToken)).ConfigureAwait(false); // client should send the first blob } catch (Exception e) { @@ -751,7 +754,8 @@ private bool CheckSpn() } // Client authentication starts here, but server also loops through this method. - private async Task SendBlobAsync(TAdapter adapter, byte[]? message) where TAdapter : IReadWriteAdapter + private async Task SendBlobAsync(byte[]? message, CancellationToken cancellationToken) + where TIOAdapter : IReadWriteAdapter { Debug.Assert(_context != null); @@ -764,7 +768,7 @@ private async Task SendBlobAsync(TAdapter adapter, byte[]? message) wh if (exception != null) { // Signal remote side on a failed attempt. - await SendAuthResetSignalAndThrowAsync(adapter, message!, exception).ConfigureAwait(false); + await SendAuthResetSignalAndThrowAsync(message!, exception, cancellationToken).ConfigureAwait(false); Debug.Fail("Unreachable"); } @@ -782,7 +786,7 @@ private async Task SendBlobAsync(TAdapter adapter, byte[]? message) wh statusCode = (int)((uint)statusCode >> 8); } - await SendAuthResetSignalAndThrowAsync(adapter, message, exception).ConfigureAwait(false); + await SendAuthResetSignalAndThrowAsync(message, exception, cancellationToken).ConfigureAwait(false); Debug.Fail("Unreachable"); } @@ -798,7 +802,7 @@ private async Task SendBlobAsync(TAdapter adapter, byte[]? message) wh statusCode = (int)((uint)statusCode >> 8); } - await SendAuthResetSignalAndThrowAsync(adapter, message, exception).ConfigureAwait(false); + await SendAuthResetSignalAndThrowAsync(message, exception, cancellationToken).ConfigureAwait(false); Debug.Fail("Unreachable"); } @@ -816,7 +820,7 @@ private async Task SendBlobAsync(TAdapter adapter, byte[]? message) wh statusCode = (int)((uint)statusCode >> 8); } - await SendAuthResetSignalAndThrowAsync(adapter, message, exception).ConfigureAwait(false); + await SendAuthResetSignalAndThrowAsync(message, exception, cancellationToken).ConfigureAwait(false); Debug.Fail("Unreachable"); } @@ -840,7 +844,7 @@ private async Task SendBlobAsync(TAdapter adapter, byte[]? message) wh if (message != null) { //even if we are completed, there could be a blob for sending. - await _framer!.WriteMessageAsync(adapter, message).ConfigureAwait(false); + await _framer!.WriteMessageAsync(InnerStream, message, cancellationToken).ConfigureAwait(false); } if (HandshakeComplete && _remoteOk) @@ -849,15 +853,16 @@ private async Task SendBlobAsync(TAdapter adapter, byte[]? message) wh return; } - await ReceiveBlobAsync(adapter).ConfigureAwait(false); + await ReceiveBlobAsync(cancellationToken).ConfigureAwait(false); } // Server authentication starts here, but client also loops through this method. - private async Task ReceiveBlobAsync(TAdapter adapter) where TAdapter : IReadWriteAdapter + private async Task ReceiveBlobAsync(CancellationToken cancellationToken) + where TIOAdapter : IReadWriteAdapter { Debug.Assert(_framer != null); - byte[]? message = await _framer.ReadMessageAsync(adapter).ConfigureAwait(false); + byte[]? message = await _framer.ReadMessageAsync(InnerStream, cancellationToken).ConfigureAwait(false); if (message == null) { // This is an EOF otherwise we would get at least *empty* message but not a null one. @@ -903,12 +908,13 @@ private async Task ReceiveBlobAsync(TAdapter adapter) where TAdapter : } // Not yet done, get a new blob and send it if any. - await SendBlobAsync(adapter, message).ConfigureAwait(false); + await SendBlobAsync(message, cancellationToken).ConfigureAwait(false); } // This is to reset auth state on the remote side. // If this write succeeds we will allow auth retrying. - private async Task SendAuthResetSignalAndThrowAsync(TAdapter adapter, byte[] message, Exception exception) where TAdapter : IReadWriteAdapter + private async Task SendAuthResetSignalAndThrowAsync(byte[] message, Exception exception, CancellationToken cancellationToken) + where TIOAdapter : IReadWriteAdapter { _framer!.WriteHeader.MessageId = FrameHeader.HandshakeErrId; @@ -922,7 +928,7 @@ private async Task SendAuthResetSignalAndThrowAsync(TAdapter adapter, exception = new AuthenticationException(SR.net_auth_SSPI, exception); } - await _framer.WriteMessageAsync(adapter, message).ConfigureAwait(false); + await _framer.WriteMessageAsync(InnerStream, message, cancellationToken).ConfigureAwait(false); _canRetryAuthentication = true; ExceptionDispatchInfo.Throw(exception); diff --git a/src/libraries/System.Net.Security/src/System/Net/Security/ReadWriteAdapter.cs b/src/libraries/System.Net.Security/src/System/Net/Security/ReadWriteAdapter.cs index 2833fe1c7d5bce..32f1f3321677bd 100644 --- a/src/libraries/System.Net.Security/src/System/Net/Security/ReadWriteAdapter.cs +++ b/src/libraries/System.Net.Security/src/System/Net/Security/ReadWriteAdapter.cs @@ -7,69 +7,50 @@ namespace System.Net.Security { +#pragma warning disable CA2252 // This API requires opting into preview features internal interface IReadWriteAdapter { - ValueTask ReadAsync(Memory buffer); - - ValueTask WriteAsync(byte[] buffer, int offset, int count); - - Task WaitAsync(TaskCompletionSource waiter); - - Task FlushAsync(); - - CancellationToken CancellationToken { get; } + static abstract ValueTask ReadAsync(Stream stream, Memory buffer, CancellationToken cancellationToken); + static abstract ValueTask WriteAsync(Stream stream, byte[] buffer, int offset, int count, CancellationToken cancellationToken); + static abstract Task FlushAsync(Stream stream, CancellationToken cancellationToken); + static abstract Task WaitAsync(TaskCompletionSource waiter); } +#pragma warning restore CA2252 - internal readonly struct AsyncReadWriteAdapter : IReadWriteAdapter + internal sealed class AsyncReadWriteAdapter : IReadWriteAdapter { - private readonly Stream _stream; - - public AsyncReadWriteAdapter(Stream stream, CancellationToken cancellationToken) - { - _stream = stream; - CancellationToken = cancellationToken; - } + public static ValueTask ReadAsync(Stream stream, Memory buffer, CancellationToken cancellationToken) => + stream.ReadAsync(buffer, cancellationToken); - public ValueTask ReadAsync(Memory buffer) => - _stream.ReadAsync(buffer, CancellationToken); + public static ValueTask WriteAsync(Stream stream, byte[] buffer, int offset, int count, CancellationToken cancellationToken) => + stream.WriteAsync(new ReadOnlyMemory(buffer, offset, count), cancellationToken); - public ValueTask WriteAsync(byte[] buffer, int offset, int count) => - _stream.WriteAsync(new ReadOnlyMemory(buffer, offset, count), CancellationToken); + public static Task FlushAsync(Stream stream, CancellationToken cancellationToken) => stream.FlushAsync(cancellationToken); - public Task WaitAsync(TaskCompletionSource waiter) => waiter.Task; - - public Task FlushAsync() => _stream.FlushAsync(CancellationToken); - - public CancellationToken CancellationToken { get; } + public static Task WaitAsync(TaskCompletionSource waiter) => waiter.Task; } - internal readonly struct SyncReadWriteAdapter : IReadWriteAdapter + internal sealed class SyncReadWriteAdapter : IReadWriteAdapter { - private readonly Stream _stream; - - public SyncReadWriteAdapter(Stream stream) => _stream = stream; - - public ValueTask ReadAsync(Memory buffer) => - new ValueTask(_stream.Read(buffer.Span)); + public static ValueTask ReadAsync(Stream stream, Memory buffer, CancellationToken cancellationToken) => + new ValueTask(stream.Read(buffer.Span)); - public ValueTask WriteAsync(byte[] buffer, int offset, int count) + public static ValueTask WriteAsync(Stream stream, byte[] buffer, int offset, int count, CancellationToken cancellationToken) { - _stream.Write(buffer, offset, count); + stream.Write(buffer, offset, count); return default; } - public Task WaitAsync(TaskCompletionSource waiter) + public static Task FlushAsync(Stream stream, CancellationToken cancellationToken) { - waiter.Task.GetAwaiter().GetResult(); + stream.Flush(); return Task.CompletedTask; } - public Task FlushAsync() + public static Task WaitAsync(TaskCompletionSource waiter) { - _stream.Flush(); + waiter.Task.GetAwaiter().GetResult(); return Task.CompletedTask; } - - public CancellationToken CancellationToken => default; } } diff --git a/src/libraries/System.Net.Security/src/System/Net/Security/SslStream.Implementation.cs b/src/libraries/System.Net.Security/src/System/Net/Security/SslStream.Implementation.cs index 57f0718b20f8b9..91e1de1cebb73a 100644 --- a/src/libraries/System.Net.Security/src/System/Net/Security/SslStream.Implementation.cs +++ b/src/libraries/System.Net.Security/src/System/Net/Security/SslStream.Implementation.cs @@ -172,8 +172,8 @@ private Task ProcessAuthenticationAsync(bool isAsync = false, bool isApm = false else { return isAsync ? - ForceAuthenticationAsync(new AsyncReadWriteAdapter(InnerStream, cancellationToken), _context!.IsServer, null, isApm) : - ForceAuthenticationAsync(new SyncReadWriteAdapter(InnerStream), _context!.IsServer, null); + ForceAuthenticationAsync(_context!.IsServer, null, isApm, cancellationToken) : + ForceAuthenticationAsync(_context!.IsServer, null, isApm: false, cancellationToken); } } @@ -185,8 +185,8 @@ private async Task ProcessAuthenticationWithTelemetryAsync(bool isAsync, bool is try { Task task = isAsync? - ForceAuthenticationAsync(new AsyncReadWriteAdapter(InnerStream, cancellationToken), _context!.IsServer, null, isApm) : - ForceAuthenticationAsync(new SyncReadWriteAdapter(InnerStream), _context!.IsServer, null); + ForceAuthenticationAsync(_context!.IsServer, null, isApm, cancellationToken) : + ForceAuthenticationAsync(_context!.IsServer, null, isApm: false, cancellationToken); await task.ConfigureAwait(false); @@ -206,12 +206,12 @@ private async Task ProcessAuthenticationWithTelemetryAsync(bool isAsync, bool is // // This is used to reply on re-handshake when received SEC_I_RENEGOTIATE on Read(). // - private async Task ReplyOnReAuthenticationAsync(TIOAdapter adapter, byte[]? buffer) + private async Task ReplyOnReAuthenticationAsync(byte[]? buffer, CancellationToken cancellationToken) where TIOAdapter : IReadWriteAdapter { try { - await ForceAuthenticationAsync(adapter, receiveFirst: false, buffer).ConfigureAwait(false); + await ForceAuthenticationAsync(receiveFirst: false, buffer, isApm: false, cancellationToken).ConfigureAwait(false); } finally { @@ -221,7 +221,7 @@ private async Task ReplyOnReAuthenticationAsync(TIOAdapter adapter, } // This will initiate renegotiation or PHA for Tls1.3 - private async Task RenegotiateAsync(TIOAdapter adapter) + private async Task RenegotiateAsync(CancellationToken cancellationToken) where TIOAdapter : IReadWriteAdapter { if (Interlocked.Exchange(ref _nestedAuth, 1) == 1) @@ -253,10 +253,10 @@ private async Task RenegotiateAsync(TIOAdapter adapter) SecurityStatusPal status = _context!.Renegotiate(out byte[]? nextmsg); - if (nextmsg is {} && nextmsg.Length > 0) + if (nextmsg is { Length: > 0 }) { - await adapter.WriteAsync(nextmsg, 0, nextmsg.Length).ConfigureAwait(false); - await adapter.FlushAsync().ConfigureAwait(false); + await TIOAdapter.WriteAsync(InnerStream, nextmsg, 0, nextmsg.Length, cancellationToken).ConfigureAwait(false); + await TIOAdapter.FlushAsync(InnerStream, cancellationToken).ConfigureAwait(false); } if (status.ErrorCode != SecurityStatusPalErrorCode.OK) @@ -274,11 +274,11 @@ private async Task RenegotiateAsync(TIOAdapter adapter) ProtocolToken message = null!; do { - message = await ReceiveBlobAsync(adapter).ConfigureAwait(false); + message = await ReceiveBlobAsync(cancellationToken).ConfigureAwait(false); if (message.Size > 0) { - await adapter.WriteAsync(message.Payload!, 0, message.Size).ConfigureAwait(false); - await adapter.FlushAsync().ConfigureAwait(false); + await TIOAdapter.WriteAsync(InnerStream, message.Payload!, 0, message.Size, cancellationToken).ConfigureAwait(false); + await TIOAdapter.FlushAsync(InnerStream, cancellationToken).ConfigureAwait(false); } } while (message.Status.ErrorCode == SecurityStatusPalErrorCode.ContinueNeeded); @@ -299,8 +299,8 @@ private async Task RenegotiateAsync(TIOAdapter adapter) } // reAuthenticationData is only used on Windows in case of renegotiation. - private async Task ForceAuthenticationAsync(TIOAdapter adapter, bool receiveFirst, byte[]? reAuthenticationData, bool isApm = false) - where TIOAdapter : IReadWriteAdapter + private async Task ForceAuthenticationAsync(bool receiveFirst, byte[]? reAuthenticationData, bool isApm, CancellationToken cancellationToken) + where TIOAdapter : IReadWriteAdapter { ProtocolToken message; bool handshakeCompleted = false; @@ -316,14 +316,13 @@ private async Task ForceAuthenticationAsync(TIOAdapter adapter, bool try { - if (!receiveFirst) { message = _context!.NextMessage(reAuthenticationData); if (message.Size > 0) { - await adapter.WriteAsync(message.Payload!, 0, message.Size).ConfigureAwait(false); - await adapter.FlushAsync().ConfigureAwait(false); + await TIOAdapter.WriteAsync(InnerStream, message.Payload!, 0, message.Size, cancellationToken).ConfigureAwait(false); + await TIOAdapter.FlushAsync(InnerStream, cancellationToken).ConfigureAwait(false); if (NetEventSource.Log.IsEnabled()) NetEventSource.Log.SentFrame(this, message.Payload); } @@ -347,7 +346,7 @@ private async Task ForceAuthenticationAsync(TIOAdapter adapter, bool while (!handshakeCompleted) { - message = await ReceiveBlobAsync(adapter).ConfigureAwait(false); + message = await ReceiveBlobAsync(cancellationToken).ConfigureAwait(false); byte[]? payload = null; int size = 0; @@ -366,8 +365,8 @@ private async Task ForceAuthenticationAsync(TIOAdapter adapter, bool if (payload != null && size > 0) { // If there is message send it out even if call failed. It may contain TLS Alert. - await adapter.WriteAsync(payload!, 0, size).ConfigureAwait(false); - await adapter.FlushAsync().ConfigureAwait(false); + await TIOAdapter.WriteAsync(InnerStream, payload!, 0, size, cancellationToken).ConfigureAwait(false); + await TIOAdapter.FlushAsync(InnerStream, cancellationToken).ConfigureAwait(false); if (NetEventSource.Log.IsEnabled()) NetEventSource.Log.SentFrame(this, payload); @@ -416,10 +415,10 @@ private async Task ForceAuthenticationAsync(TIOAdapter adapter, bool } - private async ValueTask ReceiveBlobAsync(TIOAdapter adapter) - where TIOAdapter : IReadWriteAdapter + private async ValueTask ReceiveBlobAsync(CancellationToken cancellationToken) + where TIOAdapter : IReadWriteAdapter { - int frameSize = await EnsureFullTlsFrameAsync(adapter).ConfigureAwait(false); + int frameSize = await EnsureFullTlsFrameAsync(cancellationToken).ConfigureAwait(false); if (frameSize == 0) { @@ -462,7 +461,7 @@ private async ValueTask ReceiveBlobAsync(TIOAdapter a { SslServerAuthenticationOptions userOptions = await _sslAuthenticationOptions.ServerOptionDelegate(this, new SslClientHelloInfo(_sslAuthenticationOptions.TargetHost, _lastFrame.SupportedVersions), - _sslAuthenticationOptions.UserState, adapter.CancellationToken).ConfigureAwait(false); + _sslAuthenticationOptions.UserState, cancellationToken).ConfigureAwait(false); _sslAuthenticationOptions.UpdateOptions(userOptions); } } @@ -598,19 +597,19 @@ private void CompleteHandshake(SslAuthenticationOptions sslAuthenticationOptions } } - private async ValueTask WriteAsyncChunked(TIOAdapter writeAdapter, ReadOnlyMemory buffer) - where TIOAdapter : struct, IReadWriteAdapter + private async ValueTask WriteAsyncChunked(ReadOnlyMemory buffer, CancellationToken cancellationToken) + where TIOAdapter : IReadWriteAdapter { do { int chunkBytes = Math.Min(buffer.Length, MaxDataSize); - await WriteSingleChunk(writeAdapter, buffer.Slice(0, chunkBytes)).ConfigureAwait(false); + await WriteSingleChunk(buffer.Slice(0, chunkBytes), cancellationToken).ConfigureAwait(false); buffer = buffer.Slice(chunkBytes); } while (buffer.Length != 0); } - private ValueTask WriteSingleChunk(TIOAdapter writeAdapter, ReadOnlyMemory buffer) - where TIOAdapter : struct, IReadWriteAdapter + private ValueTask WriteSingleChunk(ReadOnlyMemory buffer, CancellationToken cancellationToken) + where TIOAdapter : IReadWriteAdapter { byte[] rentedBuffer = ArrayPool.Shared.Rent(buffer.Length + FrameOverhead); byte[] outBuffer = rentedBuffer; @@ -630,7 +629,7 @@ private ValueTask WriteSingleChunk(TIOAdapter writeAdapter, ReadOnly TaskCompletionSource? waiter = _handshakeWaiter; if (waiter != null) { - Task waiterTask = writeAdapter.WaitAsync(waiter); + Task waiterTask = TIOAdapter.WaitAsync(waiter); // We finished synchronously waiting for renegotiation. We can try again immediately. if (waiterTask.IsCompletedSuccessfully) { @@ -638,7 +637,7 @@ private ValueTask WriteSingleChunk(TIOAdapter writeAdapter, ReadOnly } // We need to wait asynchronously as well as for the write when EncryptData is finished. - return WaitAndWriteAsync(writeAdapter, buffer, waiterTask, rentedBuffer); + return WaitAndWriteAsync(buffer, waiterTask, rentedBuffer, cancellationToken); } } @@ -648,7 +647,7 @@ private ValueTask WriteSingleChunk(TIOAdapter writeAdapter, ReadOnly return ValueTask.FromException(ExceptionDispatchInfo.SetCurrentStackTrace(new IOException(SR.net_io_encrypt, SslStreamPal.GetException(status)))); } - ValueTask t = writeAdapter.WriteAsync(outBuffer, 0, encryptedBytes); + ValueTask t = TIOAdapter.WriteAsync(InnerStream, outBuffer, 0, encryptedBytes, cancellationToken); if (t.IsCompletedSuccessfully) { ArrayPool.Shared.Return(rentedBuffer); @@ -659,7 +658,7 @@ private ValueTask WriteSingleChunk(TIOAdapter writeAdapter, ReadOnly return CompleteWriteAsync(t, rentedBuffer); } - async ValueTask WaitAndWriteAsync(TIOAdapter writeAdapter, ReadOnlyMemory buffer, Task waitTask, byte[] rentedBuffer) + async ValueTask WaitAndWriteAsync(ReadOnlyMemory buffer, Task waitTask, byte[] rentedBuffer, CancellationToken cancellationToken) { byte[]? bufferToReturn = rentedBuffer; byte[] outBuffer = rentedBuffer; @@ -678,11 +677,11 @@ async ValueTask WaitAndWriteAsync(TIOAdapter writeAdapter, ReadOnlyMemory // Call WriteSingleChunk() recursively to avoid code duplication. // This should be extremely rare in cases when second renegotiation happens concurrently with Write. - await WriteSingleChunk(writeAdapter, buffer).ConfigureAwait(false); + await WriteSingleChunk(buffer, cancellationToken).ConfigureAwait(false); } else if (status.ErrorCode == SecurityStatusPalErrorCode.OK) { - await writeAdapter.WriteAsync(outBuffer, 0, encryptedBytes).ConfigureAwait(false); + await TIOAdapter.WriteAsync(InnerStream, outBuffer, 0, encryptedBytes, cancellationToken).ConfigureAwait(false); } else { @@ -698,7 +697,7 @@ async ValueTask WaitAndWriteAsync(TIOAdapter writeAdapter, ReadOnlyMemory } } - async ValueTask CompleteWriteAsync(ValueTask writeTask, byte[] bufferToReturn) + static async ValueTask CompleteWriteAsync(ValueTask writeTask, byte[] bufferToReturn) { try { @@ -737,7 +736,7 @@ private bool HaveFullTlsFrame(out int frameSize) } - private async ValueTask EnsureFullTlsFrameAsync(TIOAdapter adapter) + private async ValueTask EnsureFullTlsFrameAsync(CancellationToken cancellationToken) where TIOAdapter : IReadWriteAdapter { int frameSize; @@ -752,7 +751,7 @@ private async ValueTask EnsureFullTlsFrameAsync(TIOAdapter adap Debug.Assert(_buffer.AvailableLength > 0, "_buffer.AvailableBytes > 0"); // We either don't have full frame or we don't have enough data to even determine the size. - int bytesRead = await adapter.ReadAsync(_buffer.AvailableMemory).ConfigureAwait(false); + int bytesRead = await TIOAdapter.ReadAsync(InnerStream, _buffer.AvailableMemory, cancellationToken).ConfigureAwait(false); if (bytesRead == 0) { if (_buffer.EncryptedLength != 0) @@ -816,7 +815,7 @@ private SecurityStatusPal DecryptData(int frameSize) return status; } - private async ValueTask ReadAsyncInternal(TIOAdapter adapter, Memory buffer) + private async ValueTask ReadAsyncInternal(Memory buffer, CancellationToken cancellationToken) where TIOAdapter : IReadWriteAdapter { if (Interlocked.Exchange(ref _nestedRead, 1) == 1) @@ -859,7 +858,7 @@ private async ValueTask ReadAsyncInternal(TIOAdapter adapter, M // until data is actually available from the underlying stream. // Note that if the underlying stream does not supporting blocking on zero byte reads, then this will // complete immediately and won't save any memory, but will still function correctly. - await adapter.ReadAsync(Memory.Empty).ConfigureAwait(false); + await TIOAdapter.ReadAsync(InnerStream, Memory.Empty, cancellationToken).ConfigureAwait(false); } Debug.Assert(_buffer.DecryptedLength == 0); @@ -868,7 +867,7 @@ private async ValueTask ReadAsyncInternal(TIOAdapter adapter, M while (true) { - payloadBytes = await EnsureFullTlsFrameAsync(adapter).ConfigureAwait(false); + payloadBytes = await EnsureFullTlsFrameAsync(cancellationToken).ConfigureAwait(false); if (payloadBytes == 0) { _receivedEOF = true; @@ -897,7 +896,7 @@ private async ValueTask ReadAsyncInternal(TIOAdapter adapter, M { throw new IOException(SR.net_ssl_io_renego); } - await ReplyOnReAuthenticationAsync(adapter, extraBuffer).ConfigureAwait(false); + await ReplyOnReAuthenticationAsync(extraBuffer, cancellationToken).ConfigureAwait(false); // Loop on read. continue; } @@ -951,7 +950,7 @@ private async ValueTask ReadAsyncInternal(TIOAdapter adapter, M } catch (Exception e) { - if (e is IOException || (e is OperationCanceledException && adapter.CancellationToken.IsCancellationRequested)) + if (e is IOException || (e is OperationCanceledException && cancellationToken.IsCancellationRequested)) { throw; } @@ -965,8 +964,8 @@ private async ValueTask ReadAsyncInternal(TIOAdapter adapter, M } } - private async ValueTask WriteAsyncInternal(TIOAdapter writeAdapter, ReadOnlyMemory buffer) - where TIOAdapter : struct, IReadWriteAdapter + private async ValueTask WriteAsyncInternal(ReadOnlyMemory buffer, CancellationToken cancellationToken) + where TIOAdapter : IReadWriteAdapter { ThrowIfExceptionalOrNotAuthenticatedOrShutdown(); @@ -984,13 +983,13 @@ private async ValueTask WriteAsyncInternal(TIOAdapter writeAdapter, try { ValueTask t = buffer.Length < MaxDataSize ? - WriteSingleChunk(writeAdapter, buffer) : - WriteAsyncChunked(writeAdapter, buffer); + WriteSingleChunk(buffer, cancellationToken) : + WriteAsyncChunked(buffer, cancellationToken); await t.ConfigureAwait(false); } catch (Exception e) { - if (e is IOException || (e is OperationCanceledException && writeAdapter.CancellationToken.IsCancellationRequested)) + if (e is IOException || (e is OperationCanceledException && cancellationToken.IsCancellationRequested)) { throw; } diff --git a/src/libraries/System.Net.Security/src/System/Net/Security/SslStream.cs b/src/libraries/System.Net.Security/src/System/Net/Security/SslStream.cs index 95acf46deea818..4428b8a53c2c83 100644 --- a/src/libraries/System.Net.Security/src/System/Net/Security/SslStream.cs +++ b/src/libraries/System.Net.Security/src/System/Net/Security/SslStream.cs @@ -50,7 +50,6 @@ public partial class SslStream : AuthenticatedStream internal LocalCertSelectionCallback? _certSelectionDelegate; internal EncryptionPolicy _encryptionPolicy; - private readonly Stream _innerStream; private SecureChannel? _context; private ExceptionDispatchInfo? _exception; @@ -222,8 +221,6 @@ public SslStream(Stream innerStream, bool leaveInnerStreamOpen, RemoteCertificat _encryptionPolicy = encryptionPolicy; _certSelectionDelegate = userCertificateSelectionCallback == null ? null : new LocalCertSelectionCallback(UserCertSelectionCallbackWrapper); - _innerStream = innerStream; - if (NetEventSource.Log.IsEnabled()) NetEventSource.Log.SslStreamCtor(this, innerStream); } @@ -802,7 +799,7 @@ public virtual Task NegotiateClientCertificateAsync(CancellationToken cancellati throw new InvalidOperationException(SR.net_ssl_certificate_exist); } - return RenegotiateAsync(new AsyncReadWriteAdapter(InnerStream, cancellationToken)); + return RenegotiateAsync(cancellationToken); } protected override void Dispose(bool disposing) @@ -869,7 +866,7 @@ public override int Read(byte[] buffer, int offset, int count) { ThrowIfExceptionalOrNotAuthenticated(); ValidateBufferArguments(buffer, offset, count); - ValueTask vt = ReadAsyncInternal(new SyncReadWriteAdapter(InnerStream), new Memory(buffer, offset, count)); + ValueTask vt = ReadAsyncInternal(new Memory(buffer, offset, count), default(CancellationToken)); Debug.Assert(vt.IsCompleted, "Sync operation must have completed synchronously"); return vt.GetAwaiter().GetResult(); } @@ -881,7 +878,7 @@ public override void Write(byte[] buffer, int offset, int count) ThrowIfExceptionalOrNotAuthenticated(); ValidateBufferArguments(buffer, offset, count); - ValueTask vt = WriteAsyncInternal(new SyncReadWriteAdapter(InnerStream), new ReadOnlyMemory(buffer, offset, count)); + ValueTask vt = WriteAsyncInternal(new ReadOnlyMemory(buffer, offset, count), default(CancellationToken)); Debug.Assert(vt.IsCompleted, "Sync operation must have completed synchronously"); vt.GetAwaiter().GetResult(); } @@ -914,26 +911,26 @@ public override Task WriteAsync(byte[] buffer, int offset, int count, Cancellati { ThrowIfExceptionalOrNotAuthenticated(); ValidateBufferArguments(buffer, offset, count); - return WriteAsync(new ReadOnlyMemory(buffer, offset, count), cancellationToken).AsTask(); + return WriteAsyncInternal(new ReadOnlyMemory(buffer, offset, count), cancellationToken).AsTask(); } public override ValueTask WriteAsync(ReadOnlyMemory buffer, CancellationToken cancellationToken = default) { ThrowIfExceptionalOrNotAuthenticated(); - return WriteAsyncInternal(new AsyncReadWriteAdapter(InnerStream, cancellationToken), buffer); + return WriteAsyncInternal(buffer, cancellationToken); } public override Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) { ThrowIfExceptionalOrNotAuthenticated(); ValidateBufferArguments(buffer, offset, count); - return ReadAsyncInternal(new AsyncReadWriteAdapter(InnerStream, cancellationToken), new Memory(buffer, offset, count)).AsTask(); + return ReadAsyncInternal(new Memory(buffer, offset, count), cancellationToken).AsTask(); } public override ValueTask ReadAsync(Memory buffer, CancellationToken cancellationToken = default) { ThrowIfExceptionalOrNotAuthenticated(); - return ReadAsyncInternal(new AsyncReadWriteAdapter(InnerStream, cancellationToken), buffer); + return ReadAsyncInternal(buffer, cancellationToken); } private void ThrowIfExceptional() diff --git a/src/libraries/System.Net.Security/src/System/Net/StreamFramer.cs b/src/libraries/System.Net.Security/src/System/Net/StreamFramer.cs index 64da83709bc025..9b97a29888afac 100644 --- a/src/libraries/System.Net.Security/src/System/Net/StreamFramer.cs +++ b/src/libraries/System.Net.Security/src/System/Net/StreamFramer.cs @@ -5,6 +5,7 @@ using System.Globalization; using System.Net.Security; using System.Threading.Tasks; +using System.Threading; namespace System.Net { @@ -20,7 +21,8 @@ internal sealed class StreamFramer public FrameHeader ReadHeader => _curReadHeader; public FrameHeader WriteHeader => _writeHeader; - public async ValueTask ReadMessageAsync(TAdapter adapter) where TAdapter : IReadWriteAdapter + public async ValueTask ReadMessageAsync(Stream stream, CancellationToken cancellationToken) + where TAdapter : IReadWriteAdapter { if (_eof) { @@ -33,7 +35,7 @@ internal sealed class StreamFramer int offset = 0; while (offset < buffer.Length) { - bytesRead = await adapter.ReadAsync(buffer.AsMemory(offset)).ConfigureAwait(false); + bytesRead = await TAdapter.ReadAsync(stream, buffer.AsMemory(offset), cancellationToken).ConfigureAwait(false); if (bytesRead == 0) { if (offset == 0) @@ -62,7 +64,7 @@ internal sealed class StreamFramer offset = 0; while (offset < buffer.Length) { - bytesRead = await adapter.ReadAsync(buffer.AsMemory(offset)).ConfigureAwait(false); + bytesRead = await TAdapter.ReadAsync(stream, buffer.AsMemory(offset), cancellationToken).ConfigureAwait(false); if (bytesRead == 0) { throw new IOException(SR.Format(SR.net_io_readfailure, SR.net_io_connectionclosed)); @@ -73,15 +75,16 @@ internal sealed class StreamFramer return buffer; } - public async Task WriteMessageAsync(TAdapter adapter, byte[] message!!) where TAdapter : IReadWriteAdapter + public async Task WriteMessageAsync(Stream stream, byte[] message, CancellationToken cancellationToken) + where TAdapter : IReadWriteAdapter { _writeHeader.PayloadSize = message.Length; _writeHeader.CopyTo(_writeHeaderBuffer, 0); - await adapter.WriteAsync(_writeHeaderBuffer, 0, _writeHeaderBuffer.Length).ConfigureAwait(false); + await TAdapter.WriteAsync(stream, _writeHeaderBuffer, 0, _writeHeaderBuffer.Length, cancellationToken).ConfigureAwait(false); if (message.Length != 0) { - await adapter.WriteAsync(message, 0, message.Length).ConfigureAwait(false); + await TAdapter.WriteAsync(stream, message, 0, message.Length, cancellationToken).ConfigureAwait(false); } } } From eebc60768efbc71efbd92d6dc46c62168336eb30 Mon Sep 17 00:00:00 2001 From: Stephen Toub Date: Fri, 11 Feb 2022 21:14:24 -0500 Subject: [PATCH 2/2] Remove unnecessary argument from a bunch of methods It's only used to add duplicative information to an exception message, and in doing so it makes the async methods it's used in more expensive. --- .../src/Resources/Strings.resx | 2 +- .../System/Net/Security/NegotiateStream.cs | 10 ++-- .../System/Net/Security/ReadWriteAdapter.cs | 4 +- .../Net/Security/SslStream.Implementation.cs | 47 ++++++++++--------- .../src/System/Net/Security/SslStream.cs | 12 ++--- .../Fakes/FakeSslStream.Implementation.cs | 11 +++-- 6 files changed, 44 insertions(+), 42 deletions(-) diff --git a/src/libraries/System.Net.Security/src/Resources/Strings.resx b/src/libraries/System.Net.Security/src/Resources/Strings.resx index 5373ba6ee33071..dd3b86e0b682bd 100644 --- a/src/libraries/System.Net.Security/src/Resources/Strings.resx +++ b/src/libraries/System.Net.Security/src/Resources/Strings.resx @@ -135,7 +135,7 @@ The connection was closed - The {0} method cannot be called when another {1} operation is pending. + This method may not be called when another {0} operation is pending. {0} can only be called once for each asynchronous operation. diff --git a/src/libraries/System.Net.Security/src/System/Net/Security/NegotiateStream.cs b/src/libraries/System.Net.Security/src/System/Net/Security/NegotiateStream.cs index ba2cc708bab886..447cf68d89bbf4 100644 --- a/src/libraries/System.Net.Security/src/System/Net/Security/NegotiateStream.cs +++ b/src/libraries/System.Net.Security/src/System/Net/Security/NegotiateStream.cs @@ -341,12 +341,12 @@ public override ValueTask ReadAsync(Memory buffer, CancellationToken return ReadAsync(buffer, cancellationToken); } - private async ValueTask ReadAsync(Memory buffer, CancellationToken cancellationToken, [CallerMemberName] string? callerName = null) + private async ValueTask ReadAsync(Memory buffer, CancellationToken cancellationToken) where TIOAdapter : IReadWriteAdapter { if (Interlocked.Exchange(ref _readInProgress, 1) == 1) { - throw new NotSupportedException(SR.Format(SR.net_io_invalidnestedcall, callerName, "read")); + throw new NotSupportedException(SR.Format(SR.net_io_invalidnestedcall, "read")); } try @@ -492,7 +492,7 @@ private async Task WriteAsync(ReadOnlyMemory buffer, Cancellat { if (Interlocked.Exchange(ref _writeInProgress, 1) == 1) { - throw new NotSupportedException(SR.Format(SR.net_io_invalidnestedcall, nameof(Write), "write")); + throw new NotSupportedException(SR.Format(SR.net_io_invalidnestedcall, "write")); } try @@ -704,7 +704,7 @@ private void ThrowIfFailed(bool authSuccessCheck) } } - private async Task AuthenticateAsync(CancellationToken cancellationToken, [CallerMemberName] string? callerName = null) + private async Task AuthenticateAsync(CancellationToken cancellationToken) where TIOAdapter : IReadWriteAdapter { Debug.Assert(_context != null); @@ -712,7 +712,7 @@ private async Task AuthenticateAsync(CancellationToken cancellationT ThrowIfFailed(authSuccessCheck: false); if (Interlocked.Exchange(ref _authInProgress, 1) == 1) { - throw new InvalidOperationException(SR.Format(SR.net_io_invalidnestedcall, callerName, "authenticate")); + throw new InvalidOperationException(SR.Format(SR.net_io_invalidnestedcall, "authenticate")); } try diff --git a/src/libraries/System.Net.Security/src/System/Net/Security/ReadWriteAdapter.cs b/src/libraries/System.Net.Security/src/System/Net/Security/ReadWriteAdapter.cs index 32f1f3321677bd..213f4954b0b3da 100644 --- a/src/libraries/System.Net.Security/src/System/Net/Security/ReadWriteAdapter.cs +++ b/src/libraries/System.Net.Security/src/System/Net/Security/ReadWriteAdapter.cs @@ -17,7 +17,7 @@ internal interface IReadWriteAdapter } #pragma warning restore CA2252 - internal sealed class AsyncReadWriteAdapter : IReadWriteAdapter + internal readonly struct AsyncReadWriteAdapter : IReadWriteAdapter { public static ValueTask ReadAsync(Stream stream, Memory buffer, CancellationToken cancellationToken) => stream.ReadAsync(buffer, cancellationToken); @@ -30,7 +30,7 @@ public static ValueTask WriteAsync(Stream stream, byte[] buffer, int offset, int public static Task WaitAsync(TaskCompletionSource waiter) => waiter.Task; } - internal sealed class SyncReadWriteAdapter : IReadWriteAdapter + internal readonly struct SyncReadWriteAdapter : IReadWriteAdapter { public static ValueTask ReadAsync(Stream stream, Memory buffer, CancellationToken cancellationToken) => new ValueTask(stream.Read(buffer.Span)); diff --git a/src/libraries/System.Net.Security/src/System/Net/Security/SslStream.Implementation.cs b/src/libraries/System.Net.Security/src/System/Net/Security/SslStream.Implementation.cs index 91e1de1cebb73a..9dc0538cfca292 100644 --- a/src/libraries/System.Net.Security/src/System/Net/Security/SslStream.Implementation.cs +++ b/src/libraries/System.Net.Security/src/System/Net/Security/SslStream.Implementation.cs @@ -161,23 +161,23 @@ private SecurityStatusPal EncryptData(ReadOnlyMemory buffer, ref byte[] ou // This method assumes that a SSPI context is already in a good shape. // For example it is either a fresh context or already authenticated context that needs renegotiation. // - private Task ProcessAuthenticationAsync(bool isAsync = false, bool isApm = false, CancellationToken cancellationToken = default) + private Task ProcessAuthenticationAsync(bool isAsync = false, CancellationToken cancellationToken = default) { ThrowIfExceptional(); if (NetSecurityTelemetry.Log.IsEnabled()) { - return ProcessAuthenticationWithTelemetryAsync(isAsync, isApm, cancellationToken); + return ProcessAuthenticationWithTelemetryAsync(isAsync, cancellationToken); } else { return isAsync ? - ForceAuthenticationAsync(_context!.IsServer, null, isApm, cancellationToken) : - ForceAuthenticationAsync(_context!.IsServer, null, isApm: false, cancellationToken); + ForceAuthenticationAsync(_context!.IsServer, null, cancellationToken) : + ForceAuthenticationAsync(_context!.IsServer, null, cancellationToken); } } - private async Task ProcessAuthenticationWithTelemetryAsync(bool isAsync, bool isApm, CancellationToken cancellationToken) + private async Task ProcessAuthenticationWithTelemetryAsync(bool isAsync, CancellationToken cancellationToken) { NetSecurityTelemetry.Log.HandshakeStart(_context!.IsServer, _sslAuthenticationOptions!.TargetHost); ValueStopwatch stopwatch = ValueStopwatch.StartNew(); @@ -185,8 +185,8 @@ private async Task ProcessAuthenticationWithTelemetryAsync(bool isAsync, bool is try { Task task = isAsync? - ForceAuthenticationAsync(_context!.IsServer, null, isApm, cancellationToken) : - ForceAuthenticationAsync(_context!.IsServer, null, isApm: false, cancellationToken); + ForceAuthenticationAsync(_context!.IsServer, null, cancellationToken) : + ForceAuthenticationAsync(_context!.IsServer, null, cancellationToken); await task.ConfigureAwait(false); @@ -211,7 +211,7 @@ private async Task ReplyOnReAuthenticationAsync(byte[]? buffer, Canc { try { - await ForceAuthenticationAsync(receiveFirst: false, buffer, isApm: false, cancellationToken).ConfigureAwait(false); + await ForceAuthenticationAsync(receiveFirst: false, buffer, cancellationToken).ConfigureAwait(false); } finally { @@ -226,18 +226,18 @@ private async Task RenegotiateAsync(CancellationToken cancellationTo { if (Interlocked.Exchange(ref _nestedAuth, 1) == 1) { - throw new InvalidOperationException(SR.Format(SR.net_io_invalidnestedcall, "NegotiateClientCertificateAsync", "renegotiate")); + throw new InvalidOperationException(SR.Format(SR.net_io_invalidnestedcall, "authenticate")); } if (Interlocked.Exchange(ref _nestedRead, 1) == 1) { - throw new NotSupportedException(SR.Format(SR.net_io_invalidnestedcall, nameof(SslStream.ReadAsync), "read")); + throw new NotSupportedException(SR.Format(SR.net_io_invalidnestedcall, "read")); } if (Interlocked.Exchange(ref _nestedWrite, 1) == 1) { _nestedRead = 0; - throw new NotSupportedException(SR.Format(SR.net_io_invalidnestedcall, nameof(WriteAsync), "write")); + throw new NotSupportedException(SR.Format(SR.net_io_invalidnestedcall, "write")); } try @@ -272,15 +272,17 @@ private async Task RenegotiateAsync(CancellationToken cancellationTo _buffer.EnsureAvailableSpace(InitialHandshakeBufferSize); - ProtocolToken message = null!; - do { + ProtocolToken message; + do + { message = await ReceiveBlobAsync(cancellationToken).ConfigureAwait(false); if (message.Size > 0) { await TIOAdapter.WriteAsync(InnerStream, message.Payload!, 0, message.Size, cancellationToken).ConfigureAwait(false); await TIOAdapter.FlushAsync(InnerStream, cancellationToken).ConfigureAwait(false); } - } while (message.Status.ErrorCode == SecurityStatusPalErrorCode.ContinueNeeded); + } + while (message.Status.ErrorCode == SecurityStatusPalErrorCode.ContinueNeeded); CompleteHandshake(_sslAuthenticationOptions!); } @@ -299,7 +301,7 @@ private async Task RenegotiateAsync(CancellationToken cancellationTo } // reAuthenticationData is only used on Windows in case of renegotiation. - private async Task ForceAuthenticationAsync(bool receiveFirst, byte[]? reAuthenticationData, bool isApm, CancellationToken cancellationToken) + private async Task ForceAuthenticationAsync(bool receiveFirst, byte[]? reAuthenticationData, CancellationToken cancellationToken) where TIOAdapter : IReadWriteAdapter { ProtocolToken message; @@ -310,7 +312,7 @@ private async Task ForceAuthenticationAsync(bool receiveFirst, byte[ // prevent nesting only when authentication functions are called explicitly. e.g. handle renegotiation transparently. if (Interlocked.Exchange(ref _nestedAuth, 1) == 1) { - throw new InvalidOperationException(SR.Format(SR.net_io_invalidnestedcall, isApm ? "BeginAuthenticate" : "Authenticate", "authenticate")); + throw new InvalidOperationException(SR.Format(SR.net_io_invalidnestedcall, "authenticate")); } } @@ -820,20 +822,19 @@ private async ValueTask ReadAsyncInternal(Memory buffer, { if (Interlocked.Exchange(ref _nestedRead, 1) == 1) { - throw new NotSupportedException(SR.Format(SR.net_io_invalidnestedcall, nameof(SslStream.ReadAsync), "read")); + throw new NotSupportedException(SR.Format(SR.net_io_invalidnestedcall, "read")); } ThrowIfExceptionalOrNotAuthenticated(); - int processedLength = 0; - int payloadBytes = 0; - try { + int processedLength = 0; + if (_buffer.DecryptedLength != 0) { processedLength = CopyDecryptedData(buffer); - if (processedLength == buffer.Length || !HaveFullTlsFrame(out payloadBytes)) + if (processedLength == buffer.Length || !HaveFullTlsFrame(out _)) { // We either filled whole buffer or used all buffered frames. return processedLength; @@ -867,7 +868,7 @@ private async ValueTask ReadAsyncInternal(Memory buffer, while (true) { - payloadBytes = await EnsureFullTlsFrameAsync(cancellationToken).ConfigureAwait(false); + int payloadBytes = await EnsureFullTlsFrameAsync(cancellationToken).ConfigureAwait(false); if (payloadBytes == 0) { _receivedEOF = true; @@ -977,7 +978,7 @@ private async ValueTask WriteAsyncInternal(ReadOnlyMemory buff if (Interlocked.Exchange(ref _nestedWrite, 1) == 1) { - throw new NotSupportedException(SR.Format(SR.net_io_invalidnestedcall, nameof(WriteAsync), "write")); + throw new NotSupportedException(SR.Format(SR.net_io_invalidnestedcall, "write")); } try diff --git a/src/libraries/System.Net.Security/src/System/Net/Security/SslStream.cs b/src/libraries/System.Net.Security/src/System/Net/Security/SslStream.cs index 4428b8a53c2c83..4879d175d15249 100644 --- a/src/libraries/System.Net.Security/src/System/Net/Security/SslStream.cs +++ b/src/libraries/System.Net.Security/src/System/Net/Security/SslStream.cs @@ -472,7 +472,7 @@ public Task AuthenticateAsClientAsync(SslClientAuthenticationOptions sslClientAu ValidateCreateContext(sslClientAuthenticationOptions, _userCertificateValidationCallback, _certSelectionDelegate); - return ProcessAuthenticationAsync(isAsync: true, isApm: false, cancellationToken); + return ProcessAuthenticationAsync(isAsync: true, cancellationToken); } private Task AuthenticateAsClientApm(SslClientAuthenticationOptions sslClientAuthenticationOptions, CancellationToken cancellationToken = default) @@ -482,7 +482,7 @@ private Task AuthenticateAsClientApm(SslClientAuthenticationOptions sslClientAut ValidateCreateContext(sslClientAuthenticationOptions, _userCertificateValidationCallback, _certSelectionDelegate); - return ProcessAuthenticationAsync(isAsync: true, isApm: true, cancellationToken); + return ProcessAuthenticationAsync(isAsync: true, cancellationToken); } public virtual Task AuthenticateAsServerAsync(X509Certificate serverCertificate) => @@ -520,7 +520,7 @@ public Task AuthenticateAsServerAsync(SslServerAuthenticationOptions sslServerAu SetAndVerifyValidationCallback(sslServerAuthenticationOptions.RemoteCertificateValidationCallback); ValidateCreateContext(CreateAuthenticationOptions(sslServerAuthenticationOptions)); - return ProcessAuthenticationAsync(isAsync: true, isApm: false, cancellationToken); + return ProcessAuthenticationAsync(isAsync: true, cancellationToken); } private Task AuthenticateAsServerApm(SslServerAuthenticationOptions sslServerAuthenticationOptions, CancellationToken cancellationToken = default) @@ -528,13 +528,13 @@ private Task AuthenticateAsServerApm(SslServerAuthenticationOptions sslServerAut SetAndVerifyValidationCallback(sslServerAuthenticationOptions.RemoteCertificateValidationCallback); ValidateCreateContext(CreateAuthenticationOptions(sslServerAuthenticationOptions)); - return ProcessAuthenticationAsync(isAsync: true, isApm: true, cancellationToken); + return ProcessAuthenticationAsync(isAsync: true, cancellationToken); } public Task AuthenticateAsServerAsync(ServerOptionsSelectionCallback optionsCallback, object? state, CancellationToken cancellationToken = default) { ValidateCreateContext(new SslAuthenticationOptions(optionsCallback, state, _userCertificateValidationCallback)); - return ProcessAuthenticationAsync(isAsync: true, isApm: false, cancellationToken); + return ProcessAuthenticationAsync(isAsync: true, cancellationToken); } public virtual Task ShutdownAsync() @@ -831,7 +831,7 @@ public override int ReadByte() ThrowIfExceptionalOrNotAuthenticated(); if (Interlocked.Exchange(ref _nestedRead, 1) == 1) { - throw new NotSupportedException(SR.Format(SR.net_io_invalidnestedcall, "ReadByte", "read")); + throw new NotSupportedException(SR.Format(SR.net_io_invalidnestedcall, "read")); } // If there's any data in the buffer, take one byte, and we're done. diff --git a/src/libraries/System.Net.Security/tests/UnitTests/Fakes/FakeSslStream.Implementation.cs b/src/libraries/System.Net.Security/tests/UnitTests/Fakes/FakeSslStream.Implementation.cs index 9298542f607673..140792db4ec772 100644 --- a/src/libraries/System.Net.Security/tests/UnitTests/Fakes/FakeSslStream.Implementation.cs +++ b/src/libraries/System.Net.Security/tests/UnitTests/Fakes/FakeSslStream.Implementation.cs @@ -37,10 +37,11 @@ private void ValidateCreateContext(SslAuthenticationOptions sslAuthenticationOpt _sslAuthenticationOptions = new FakeOptions() { TargetHost = sslAuthenticationOptions.TargetHost }; } - private ValueTask WriteAsyncInternal(TWriteAdapter writeAdapter, ReadOnlyMemory buffer) - where TWriteAdapter : struct, IReadWriteAdapter => default; + private ValueTask WriteAsyncInternal(ReadOnlyMemory buffer, CancellationToken cancellationToken) + where TWriteAdapter : IReadWriteAdapter => default; - private ValueTask ReadAsyncInternal(TReadAdapter adapter, Memory buffer) => default; + private ValueTask ReadAsyncInternal(Memory buffer, CancellationToken cancellationToken) + where TReadAdapter : IReadWriteAdapter => default; private bool RemoteCertRequired => default; @@ -51,12 +52,12 @@ private void CloseInternal() // This method assumes that a SSPI context is already in a good shape. // For example it is either a fresh context or already authenticated context that needs renegotiation. // - private Task ProcessAuthenticationAsync(bool isAsync = false, bool isApm = false, CancellationToken cancellationToken = default) + private Task ProcessAuthenticationAsync(bool isAsync = false, CancellationToken cancellationToken = default) { return Task.Run(() => { }); } - private Task RenegotiateAsync(AsyncReadWriteAdapter adapter) => throw new PlatformNotSupportedException(); + private Task RenegotiateAsync(CancellationToken cancellationToken) => throw new PlatformNotSupportedException(); private void ReturnReadBufferIfEmpty() {