Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@
<value>The connection was closed</value>
</data>
<data name="net_io_invalidnestedcall" xml:space="preserve">
<value> The {0} method cannot be called when another {1} operation is pending.</value>
<value> This method may not be called when another {0} operation is pending.</value>
</data>
<data name="net_io_invalidendcall" xml:space="preserve">
<value>{0} can only be called once for each asynchronous operation.</value>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ public virtual void AuthenticateAsServer(NetworkCredential credential, Protectio
public virtual void AuthenticateAsServer(NetworkCredential credential, ExtendedProtectionPolicy? policy, ProtectionLevel requiredProtectionLevel, TokenImpersonationLevel requiredImpersonationLevel)
{
ValidateCreateContext(DefaultPackage, credential, string.Empty, policy, requiredProtectionLevel, requiredImpersonationLevel);
AuthenticateAsync(new SyncReadWriteAdapter(InnerStream)).GetAwaiter().GetResult();
AuthenticateAsync<SyncReadWriteAdapter>(default(CancellationToken)).GetAwaiter().GetResult();
}

public virtual IAsyncResult BeginAuthenticateAsServer(AsyncCallback? asyncCallback, object? asyncState) =>
Expand Down Expand Up @@ -172,7 +172,7 @@ public virtual void AuthenticateAsClient(
NetworkCredential credential, ChannelBinding? binding, string targetName, ProtectionLevel requiredProtectionLevel, TokenImpersonationLevel allowedImpersonationLevel)
{
ValidateCreateContext(DefaultPackage, isServer: false, credential, targetName, binding, requiredProtectionLevel, allowedImpersonationLevel);
AuthenticateAsync(new SyncReadWriteAdapter(InnerStream)).GetAwaiter().GetResult();
AuthenticateAsync<SyncReadWriteAdapter>(default(CancellationToken)).GetAwaiter().GetResult();
}

public virtual Task AuthenticateAsClientAsync() =>
Expand All @@ -195,7 +195,7 @@ public virtual Task AuthenticateAsClientAsync(
TokenImpersonationLevel allowedImpersonationLevel)
{
ValidateCreateContext(DefaultPackage, isServer: false, credential, targetName, binding, requiredProtectionLevel, allowedImpersonationLevel);
return AuthenticateAsync(new AsyncReadWriteAdapter(InnerStream, cancellationToken: default));
return AuthenticateAsync<AsyncReadWriteAdapter>(default(CancellationToken));
}

public virtual Task AuthenticateAsServerAsync() =>
Expand All @@ -211,7 +211,7 @@ public virtual Task AuthenticateAsServerAsync(
NetworkCredential credential, ExtendedProtectionPolicy? policy, ProtectionLevel requiredProtectionLevel, TokenImpersonationLevel requiredImpersonationLevel)
{
ValidateCreateContext(DefaultPackage, credential, string.Empty, policy, requiredProtectionLevel, requiredImpersonationLevel);
return AuthenticateAsync(new AsyncReadWriteAdapter(InnerStream, cancellationToken: default));
return AuthenticateAsync<AsyncReadWriteAdapter>(default(CancellationToken));
}

public override bool IsAuthenticated => IsAuthenticatedCore;
Expand Down Expand Up @@ -312,7 +312,7 @@ public override int Read(byte[] buffer, int offset, int count)
return InnerStream.Read(buffer, offset, count);
}

ValueTask<int> vt = ReadAsync(new SyncReadWriteAdapter(InnerStream), new Memory<byte>(buffer, offset, count));
ValueTask<int> vt = ReadAsync<SyncReadWriteAdapter>(new Memory<byte>(buffer, offset, count), default(CancellationToken));
Debug.Assert(vt.IsCompleted, "Should have completed synchroously with sync adapter");
return vt.GetAwaiter().GetResult();
}
Expand All @@ -327,7 +327,7 @@ public override Task<int> ReadAsync(byte[] buffer, int offset, int count, Cancel
return InnerStream.ReadAsync(buffer, offset, count, cancellationToken);
}

return ReadAsync(new AsyncReadWriteAdapter(InnerStream, cancellationToken), new Memory<byte>(buffer, offset, count)).AsTask();
return ReadAsync<AsyncReadWriteAdapter>(new Memory<byte>(buffer, offset, count), cancellationToken).AsTask();
}

public override ValueTask<int> ReadAsync(Memory<byte> buffer, CancellationToken cancellationToken = default)
Expand All @@ -338,14 +338,15 @@ public override ValueTask<int> ReadAsync(Memory<byte> buffer, CancellationToken
return InnerStream.ReadAsync(buffer, cancellationToken);
}

return ReadAsync(new AsyncReadWriteAdapter(InnerStream, cancellationToken), buffer);
return ReadAsync<AsyncReadWriteAdapter>(buffer, cancellationToken);
}

private async ValueTask<int> ReadAsync<TAdapter>(TAdapter adapter, Memory<byte> buffer, [CallerMemberName] string? callerName = null) where TAdapter : IReadWriteAdapter
private async ValueTask<int> ReadAsync<TIOAdapter>(Memory<byte> buffer, CancellationToken cancellationToken)
where TIOAdapter : IReadWriteAdapter
{
if (Interlocked.Exchange(ref _readInProgress, 1) == 1)
{
throw new NotSupportedException(SR.Format(SR.net_io_invalidnestedcall, callerName, "read"));
throw new NotSupportedException(SR.Format(SR.net_io_invalidnestedcall, "read"));
}

try
Expand All @@ -364,7 +365,7 @@ private async ValueTask<int> ReadAsync<TAdapter>(TAdapter adapter, Memory<byte>

while (true)
{
int readBytes = await ReadAllAsync(adapter, _readHeader, allowZeroRead: true).ConfigureAwait(false);
int readBytes = await ReadAllAsync(InnerStream, _readHeader, allowZeroRead: true, cancellationToken).ConfigureAwait(false);
if (readBytes == 0)
{
return 0;
Expand All @@ -389,7 +390,7 @@ private async ValueTask<int> ReadAsync<TAdapter>(TAdapter adapter, Memory<byte>
_readBuffer = new byte[readBytes];
}

readBytes = await ReadAllAsync(adapter, new Memory<byte>(_readBuffer, 0, readBytes), allowZeroRead: false).ConfigureAwait(false);
readBytes = await ReadAllAsync(InnerStream, new Memory<byte>(_readBuffer, 0, readBytes), allowZeroRead: false, cancellationToken).ConfigureAwait(false);

// Decrypt into internal buffer, change "readBytes" to count now _Decrypted Bytes_
// Decrypted data start from zero offset, the size can be shrunk after decryption.
Expand Down Expand Up @@ -421,13 +422,13 @@ private async ValueTask<int> ReadAsync<TAdapter>(TAdapter adapter, Memory<byte>
_readInProgress = 0;
}

static async ValueTask<int> ReadAllAsync(TAdapter adapter, Memory<byte> buffer, bool allowZeroRead)
static async ValueTask<int> ReadAllAsync(Stream stream, Memory<byte> buffer, bool allowZeroRead, CancellationToken cancellationToken)
{
int read = 0;

do
{
int bytes = await adapter.ReadAsync(buffer).ConfigureAwait(false);
int bytes = await TIOAdapter.ReadAsync(stream, buffer, cancellationToken).ConfigureAwait(false);
if (bytes == 0)
{
if (read != 0 || !allowZeroRead)
Expand Down Expand Up @@ -457,7 +458,7 @@ public override void Write(byte[] buffer, int offset, int count)
return;
}

WriteAsync(new SyncReadWriteAdapter(InnerStream), new ReadOnlyMemory<byte>(buffer, offset, count)).GetAwaiter().GetResult();
WriteAsync<SyncReadWriteAdapter>(new ReadOnlyMemory<byte>(buffer, offset, count), default(CancellationToken)).GetAwaiter().GetResult();
}

/// <returns>A <see cref="Task"/> that represents the asynchronous read operation.</returns>
Expand All @@ -471,7 +472,7 @@ public override Task WriteAsync(byte[] buffer, int offset, int count, Cancellati
return InnerStream.WriteAsync(buffer, offset, count, cancellationToken);
}

return WriteAsync(new AsyncReadWriteAdapter(InnerStream, cancellationToken), new ReadOnlyMemory<byte>(buffer, offset, count));
return WriteAsync<AsyncReadWriteAdapter>(new ReadOnlyMemory<byte>(buffer, offset, count), cancellationToken);
}

/// <returns>A <see cref="ValueTask"/> that represents the asynchronous read operation.</returns>
Expand All @@ -483,14 +484,15 @@ public override ValueTask WriteAsync(ReadOnlyMemory<byte> buffer, CancellationTo
return InnerStream.WriteAsync(buffer, cancellationToken);
}

return new ValueTask(WriteAsync(new AsyncReadWriteAdapter(InnerStream, cancellationToken), buffer));
return new ValueTask(WriteAsync<AsyncReadWriteAdapter>(buffer, cancellationToken));
}

private async Task WriteAsync<TAdapter>(TAdapter adapter, ReadOnlyMemory<byte> buffer) where TAdapter : IReadWriteAdapter
private async Task WriteAsync<TIOAdapter>(ReadOnlyMemory<byte> buffer, CancellationToken cancellationToken)
where TIOAdapter : IReadWriteAdapter
{
if (Interlocked.Exchange(ref _writeInProgress, 1) == 1)
{
throw new NotSupportedException(SR.Format(SR.net_io_invalidnestedcall, nameof(Write), "write"));
throw new NotSupportedException(SR.Format(SR.net_io_invalidnestedcall, "write"));
}

try
Expand All @@ -508,7 +510,7 @@ private async Task WriteAsync<TAdapter>(TAdapter adapter, ReadOnlyMemory<byte> b
throw new IOException(SR.net_io_encrypt, e);
}

await adapter.WriteAsync(_writeBuffer, 0, encryptedBytes).ConfigureAwait(false);
await TIOAdapter.WriteAsync(InnerStream, _writeBuffer, 0, encryptedBytes, cancellationToken).ConfigureAwait(false);
buffer = buffer.Slice(chunkBytes);
}
}
Expand Down Expand Up @@ -702,21 +704,22 @@ private void ThrowIfFailed(bool authSuccessCheck)
}
}

private async Task AuthenticateAsync<TAdapter>(TAdapter adapter, [CallerMemberName] string? callerName = null) where TAdapter : IReadWriteAdapter
private async Task AuthenticateAsync<TIOAdapter>(CancellationToken cancellationToken)
where TIOAdapter : IReadWriteAdapter
{
Debug.Assert(_context != null);

ThrowIfFailed(authSuccessCheck: false);
if (Interlocked.Exchange(ref _authInProgress, 1) == 1)
{
throw new InvalidOperationException(SR.Format(SR.net_io_invalidnestedcall, callerName, "authenticate"));
throw new InvalidOperationException(SR.Format(SR.net_io_invalidnestedcall, "authenticate"));
}

try
{
await (_context.IsServer ?
ReceiveBlobAsync(adapter) : // server should listen for a client blob
SendBlobAsync(adapter, message: null)).ConfigureAwait(false); // client should send the first blob
ReceiveBlobAsync<TIOAdapter>(cancellationToken) : // server should listen for a client blob
SendBlobAsync<TIOAdapter>(message: null, cancellationToken)).ConfigureAwait(false); // client should send the first blob
}
catch (Exception e)
{
Expand Down Expand Up @@ -751,7 +754,8 @@ private bool CheckSpn()
}

// Client authentication starts here, but server also loops through this method.
private async Task SendBlobAsync<TAdapter>(TAdapter adapter, byte[]? message) where TAdapter : IReadWriteAdapter
private async Task SendBlobAsync<TIOAdapter>(byte[]? message, CancellationToken cancellationToken)
where TIOAdapter : IReadWriteAdapter
{
Debug.Assert(_context != null);

Expand All @@ -764,7 +768,7 @@ private async Task SendBlobAsync<TAdapter>(TAdapter adapter, byte[]? message) wh
if (exception != null)
{
// Signal remote side on a failed attempt.
await SendAuthResetSignalAndThrowAsync(adapter, message!, exception).ConfigureAwait(false);
await SendAuthResetSignalAndThrowAsync<TIOAdapter>(message!, exception, cancellationToken).ConfigureAwait(false);
Debug.Fail("Unreachable");
}

Expand All @@ -782,7 +786,7 @@ private async Task SendBlobAsync<TAdapter>(TAdapter adapter, byte[]? message) wh
statusCode = (int)((uint)statusCode >> 8);
}

await SendAuthResetSignalAndThrowAsync(adapter, message, exception).ConfigureAwait(false);
await SendAuthResetSignalAndThrowAsync<TIOAdapter>(message, exception, cancellationToken).ConfigureAwait(false);
Debug.Fail("Unreachable");
}

Expand All @@ -798,7 +802,7 @@ private async Task SendBlobAsync<TAdapter>(TAdapter adapter, byte[]? message) wh
statusCode = (int)((uint)statusCode >> 8);
}

await SendAuthResetSignalAndThrowAsync(adapter, message, exception).ConfigureAwait(false);
await SendAuthResetSignalAndThrowAsync<TIOAdapter>(message, exception, cancellationToken).ConfigureAwait(false);
Debug.Fail("Unreachable");
}

Expand All @@ -816,7 +820,7 @@ private async Task SendBlobAsync<TAdapter>(TAdapter adapter, byte[]? message) wh
statusCode = (int)((uint)statusCode >> 8);
}

await SendAuthResetSignalAndThrowAsync(adapter, message, exception).ConfigureAwait(false);
await SendAuthResetSignalAndThrowAsync<TIOAdapter>(message, exception, cancellationToken).ConfigureAwait(false);
Debug.Fail("Unreachable");
}

Expand All @@ -840,7 +844,7 @@ private async Task SendBlobAsync<TAdapter>(TAdapter adapter, byte[]? message) wh
if (message != null)
{
//even if we are completed, there could be a blob for sending.
await _framer!.WriteMessageAsync(adapter, message).ConfigureAwait(false);
await _framer!.WriteMessageAsync<TIOAdapter>(InnerStream, message, cancellationToken).ConfigureAwait(false);
}

if (HandshakeComplete && _remoteOk)
Expand All @@ -849,15 +853,16 @@ private async Task SendBlobAsync<TAdapter>(TAdapter adapter, byte[]? message) wh
return;
}

await ReceiveBlobAsync(adapter).ConfigureAwait(false);
await ReceiveBlobAsync<TIOAdapter>(cancellationToken).ConfigureAwait(false);
}

// Server authentication starts here, but client also loops through this method.
private async Task ReceiveBlobAsync<TAdapter>(TAdapter adapter) where TAdapter : IReadWriteAdapter
private async Task ReceiveBlobAsync<TIOAdapter>(CancellationToken cancellationToken)
where TIOAdapter : IReadWriteAdapter
{
Debug.Assert(_framer != null);

byte[]? message = await _framer.ReadMessageAsync(adapter).ConfigureAwait(false);
byte[]? message = await _framer.ReadMessageAsync<TIOAdapter>(InnerStream, cancellationToken).ConfigureAwait(false);
if (message == null)
{
// This is an EOF otherwise we would get at least *empty* message but not a null one.
Expand Down Expand Up @@ -903,12 +908,13 @@ private async Task ReceiveBlobAsync<TAdapter>(TAdapter adapter) where TAdapter :
}

// Not yet done, get a new blob and send it if any.
await SendBlobAsync(adapter, message).ConfigureAwait(false);
await SendBlobAsync<TIOAdapter>(message, cancellationToken).ConfigureAwait(false);
}

// This is to reset auth state on the remote side.
// If this write succeeds we will allow auth retrying.
private async Task SendAuthResetSignalAndThrowAsync<TAdapter>(TAdapter adapter, byte[] message, Exception exception) where TAdapter : IReadWriteAdapter
private async Task SendAuthResetSignalAndThrowAsync<TIOAdapter>(byte[] message, Exception exception, CancellationToken cancellationToken)
where TIOAdapter : IReadWriteAdapter
{
_framer!.WriteHeader.MessageId = FrameHeader.HandshakeErrId;

Expand All @@ -922,7 +928,7 @@ private async Task SendAuthResetSignalAndThrowAsync<TAdapter>(TAdapter adapter,
exception = new AuthenticationException(SR.net_auth_SSPI, exception);
}

await _framer.WriteMessageAsync(adapter, message).ConfigureAwait(false);
await _framer.WriteMessageAsync<TIOAdapter>(InnerStream, message, cancellationToken).ConfigureAwait(false);

_canRetryAuthentication = true;
ExceptionDispatchInfo.Throw(exception);
Expand Down
Loading