From c31200efe6751b28972e8446c42204f44835d4c8 Mon Sep 17 00:00:00 2001 From: MattyLeslie Date: Thu, 30 May 2024 13:56:24 +0200 Subject: [PATCH 1/3] Ensuring adequate handling of cancellation tokens in ReadAsync methods --- .../Server/src/Circuits/RemoteJSDataStream.cs | 22 +++++-------------- 1 file changed, 6 insertions(+), 16 deletions(-) diff --git a/src/Components/Server/src/Circuits/RemoteJSDataStream.cs b/src/Components/Server/src/Circuits/RemoteJSDataStream.cs index 75c432a9ec74..3abc1a42c6bc 100644 --- a/src/Components/Server/src/Circuits/RemoteJSDataStream.cs +++ b/src/Components/Server/src/Circuits/RemoteJSDataStream.cs @@ -181,28 +181,18 @@ public override void Write(byte[] buffer, int offset, int count) public override async Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) { - var linkedCancellationToken = GetLinkedCancellationToken(_streamCancellationToken, cancellationToken); - return await _pipeReaderStream.ReadAsync(buffer.AsMemory(offset, count), linkedCancellationToken); + using (var linkedCts = CancellationTokenSource.CreateLinkedTokenSource(_streamCancellationToken, cancellationToken)) + { + return await _pipeReaderStream.ReadAsync(buffer.AsMemory(offset, count), linkedCts.Token); + } } public override async ValueTask ReadAsync(Memory buffer, CancellationToken cancellationToken = default) { - var linkedCancellationToken = GetLinkedCancellationToken(_streamCancellationToken, cancellationToken); - return await _pipeReaderStream.ReadAsync(buffer, linkedCancellationToken); - } - - private static CancellationToken GetLinkedCancellationToken(CancellationToken a, CancellationToken b) - { - if (a.CanBeCanceled && b.CanBeCanceled) - { - return CancellationTokenSource.CreateLinkedTokenSource(a, b).Token; - } - else if (a.CanBeCanceled) + using (var linkedCts = CancellationTokenSource.CreateLinkedTokenSource(_streamCancellationToken, cancellationToken)) { - return a; + return await _pipeReaderStream.ReadAsync(buffer, linkedCts.Token); } - - return b; } private async Task ThrowOnTimeout() From 708c0fdea6d42168872e7dd041ccf329b88b9b02 Mon Sep 17 00:00:00 2001 From: MattyLeslie Date: Thu, 30 May 2024 13:57:39 +0200 Subject: [PATCH 2/3] Writing tests "ReadAsync_Memory_DisposesCancellationTokenSource", && "ReadAsync_ByteArray_DisposesCancellationTokenSource" --- .../test/Circuits/RemoteJSDataStreamTest.cs | 50 +++++++++++++++++++ 1 file changed, 50 insertions(+) diff --git a/src/Components/Server/test/Circuits/RemoteJSDataStreamTest.cs b/src/Components/Server/test/Circuits/RemoteJSDataStreamTest.cs index dac28c3acea4..41e4a5744aff 100644 --- a/src/Components/Server/test/Circuits/RemoteJSDataStreamTest.cs +++ b/src/Components/Server/test/Circuits/RemoteJSDataStreamTest.cs @@ -287,6 +287,56 @@ public async Task ReceiveData_ReceivesDataThenTimesout_StreamDisposed() Assert.False(success); } + [Fact] + public async Task ReadAsync_ByteArray_DisposesCancellationTokenSource() + { + // Arrange + var jsStreamReference = Mock.Of(); + var remoteJSDataStream = await RemoteJSDataStream.CreateRemoteJSDataStreamAsync(_jsRuntime, jsStreamReference, totalLength: 100, signalRMaximumIncomingBytes: 10_000, jsInteropDefaultCallTimeout: TimeSpan.FromMinutes(1), cancellationToken: CancellationToken.None).DefaultTimeout(); + var buffer = new byte[100]; + var chunk = new byte[100]; + new Random().NextBytes(chunk); + + // Act + var cts = new CancellationTokenSource(); + var sendDataTask = Task.Run(async () => + { + await RemoteJSDataStream.ReceiveData(_jsRuntime, GetStreamId(remoteJSDataStream, _jsRuntime), chunkId: 0, chunk, error: null); + }); + + await sendDataTask; + var result = await remoteJSDataStream.ReadAsync(buffer, 0, buffer.Length, cts.Token); + + // Assert + Assert.True(cts.Token.CanBeCanceled); + Assert.False(cts.IsCancellationRequested); + } + + [Fact] + public async Task ReadAsync_Memory_DisposesCancellationTokenSource() + { + // Arrange + var jsStreamReference = Mock.Of(); + var remoteJSDataStream = await RemoteJSDataStream.CreateRemoteJSDataStreamAsync(_jsRuntime, jsStreamReference, totalLength: 100, signalRMaximumIncomingBytes: 10_000, jsInteropDefaultCallTimeout: TimeSpan.FromMinutes(1), cancellationToken: CancellationToken.None).DefaultTimeout(); + var buffer = new Memory(new byte[100]); + var chunk = new byte[100]; + new Random().NextBytes(chunk); + + // Act + var cts = new CancellationTokenSource(); + var sendDataTask = Task.Run(async () => + { + await RemoteJSDataStream.ReceiveData(_jsRuntime, GetStreamId(remoteJSDataStream, _jsRuntime), chunkId: 0, chunk, error: null); + }); + + await sendDataTask; + var result = await remoteJSDataStream.ReadAsync(buffer, cts.Token); + + // Assert + Assert.True(cts.Token.CanBeCanceled); + Assert.False(cts.IsCancellationRequested); + } + private static async Task CreateRemoteJSDataStreamAsync(TestRemoteJSRuntime jsRuntime = null) { var jsStreamReference = Mock.Of(); From 844215d1d210aaba83ecbc5a5d918cb7f839ae56 Mon Sep 17 00:00:00 2001 From: Mackinnon Buck Date: Mon, 21 Oct 2024 10:52:02 -0700 Subject: [PATCH 3/3] Suggestions from code review --- .../Server/src/Circuits/RemoteJSDataStream.cs | 53 +++++++++++-- .../test/Circuits/RemoteJSDataStreamTest.cs | 79 ++++++++++--------- 2 files changed, 87 insertions(+), 45 deletions(-) diff --git a/src/Components/Server/src/Circuits/RemoteJSDataStream.cs b/src/Components/Server/src/Circuits/RemoteJSDataStream.cs index 3abc1a42c6bc..d1cdf0724832 100644 --- a/src/Components/Server/src/Circuits/RemoteJSDataStream.cs +++ b/src/Components/Server/src/Circuits/RemoteJSDataStream.cs @@ -181,18 +181,14 @@ public override void Write(byte[] buffer, int offset, int count) public override async Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) { - using (var linkedCts = CancellationTokenSource.CreateLinkedTokenSource(_streamCancellationToken, cancellationToken)) - { - return await _pipeReaderStream.ReadAsync(buffer.AsMemory(offset, count), linkedCts.Token); - } + using var linkedCts = ValueLinkedCancellationTokenSource.Create(_streamCancellationToken, cancellationToken); + return await _pipeReaderStream.ReadAsync(buffer.AsMemory(offset, count), linkedCts.Token); } public override async ValueTask ReadAsync(Memory buffer, CancellationToken cancellationToken = default) { - using (var linkedCts = CancellationTokenSource.CreateLinkedTokenSource(_streamCancellationToken, cancellationToken)) - { - return await _pipeReaderStream.ReadAsync(buffer, linkedCts.Token); - } + using var linkedCts = ValueLinkedCancellationTokenSource.Create(_streamCancellationToken, cancellationToken); + return await _pipeReaderStream.ReadAsync(buffer, linkedCts.Token); } private async Task ThrowOnTimeout() @@ -233,4 +229,45 @@ protected override void Dispose(bool disposing) _disposed = true; } + + // A helper for creating and disposing linked CancellationTokenSources + // without allocating, when possible. + // Internal for testing. + internal readonly struct ValueLinkedCancellationTokenSource : IDisposable + { + private readonly CancellationTokenSource? _linkedCts; + + public readonly CancellationToken Token; + + // For testing. + internal bool HasLinkedCancellationTokenSource => _linkedCts is not null; + + public static ValueLinkedCancellationTokenSource Create( + CancellationToken token1, CancellationToken token2) + { + if (!token1.CanBeCanceled) + { + return new(linkedCts: null, token2); + } + + if (!token2.CanBeCanceled) + { + return new(linkedCts: null, token1); + } + + var linkedCts = CancellationTokenSource.CreateLinkedTokenSource(token1, token2); + return new(linkedCts, linkedCts.Token); + } + + private ValueLinkedCancellationTokenSource(CancellationTokenSource? linkedCts, CancellationToken token) + { + _linkedCts = linkedCts; + Token = token; + } + + public void Dispose() + { + _linkedCts?.Dispose(); + } + } } diff --git a/src/Components/Server/test/Circuits/RemoteJSDataStreamTest.cs b/src/Components/Server/test/Circuits/RemoteJSDataStreamTest.cs index 41e4a5744aff..e737cedbe7c3 100644 --- a/src/Components/Server/test/Circuits/RemoteJSDataStreamTest.cs +++ b/src/Components/Server/test/Circuits/RemoteJSDataStreamTest.cs @@ -287,54 +287,59 @@ public async Task ReceiveData_ReceivesDataThenTimesout_StreamDisposed() Assert.False(success); } - [Fact] - public async Task ReadAsync_ByteArray_DisposesCancellationTokenSource() + [Theory] + [InlineData(false)] + [InlineData(true)] + public void ValueLinkedCts_Works_WhenOneTokenCannotBeCanceled(bool isToken1Cancelable) { - // Arrange - var jsStreamReference = Mock.Of(); - var remoteJSDataStream = await RemoteJSDataStream.CreateRemoteJSDataStreamAsync(_jsRuntime, jsStreamReference, totalLength: 100, signalRMaximumIncomingBytes: 10_000, jsInteropDefaultCallTimeout: TimeSpan.FromMinutes(1), cancellationToken: CancellationToken.None).DefaultTimeout(); - var buffer = new byte[100]; - var chunk = new byte[100]; - new Random().NextBytes(chunk); - - // Act var cts = new CancellationTokenSource(); - var sendDataTask = Task.Run(async () => - { - await RemoteJSDataStream.ReceiveData(_jsRuntime, GetStreamId(remoteJSDataStream, _jsRuntime), chunkId: 0, chunk, error: null); - }); + var token1 = isToken1Cancelable ? cts.Token : CancellationToken.None; + var token2 = isToken1Cancelable ? CancellationToken.None : cts.Token; - await sendDataTask; - var result = await remoteJSDataStream.ReadAsync(buffer, 0, buffer.Length, cts.Token); + using var linkedCts = RemoteJSDataStream.ValueLinkedCancellationTokenSource.Create(token1, token2); - // Assert - Assert.True(cts.Token.CanBeCanceled); - Assert.False(cts.IsCancellationRequested); + Assert.False(linkedCts.HasLinkedCancellationTokenSource); + Assert.False(linkedCts.Token.IsCancellationRequested); + + cts.Cancel(); + + Assert.True(linkedCts.Token.IsCancellationRequested); } [Fact] - public async Task ReadAsync_Memory_DisposesCancellationTokenSource() + public void ValueLinkedCts_Works_WhenBothTokensCannotBeCanceled() { - // Arrange - var jsStreamReference = Mock.Of(); - var remoteJSDataStream = await RemoteJSDataStream.CreateRemoteJSDataStreamAsync(_jsRuntime, jsStreamReference, totalLength: 100, signalRMaximumIncomingBytes: 10_000, jsInteropDefaultCallTimeout: TimeSpan.FromMinutes(1), cancellationToken: CancellationToken.None).DefaultTimeout(); - var buffer = new Memory(new byte[100]); - var chunk = new byte[100]; - new Random().NextBytes(chunk); + using var linkedCts = RemoteJSDataStream.ValueLinkedCancellationTokenSource.Create( + CancellationToken.None, + CancellationToken.None); - // Act - var cts = new CancellationTokenSource(); - var sendDataTask = Task.Run(async () => - { - await RemoteJSDataStream.ReceiveData(_jsRuntime, GetStreamId(remoteJSDataStream, _jsRuntime), chunkId: 0, chunk, error: null); - }); + Assert.False(linkedCts.HasLinkedCancellationTokenSource); + Assert.False(linkedCts.Token.IsCancellationRequested); + } - await sendDataTask; - var result = await remoteJSDataStream.ReadAsync(buffer, cts.Token); + [Theory] + [InlineData(false, true)] + [InlineData(true, false)] + [InlineData(true, true)] + public void ValueLinkedCts_Works_WhenBothTokensCanBeCanceled(bool shouldCancelToken1, bool shouldCancelToken2) + { + var cts1 = new CancellationTokenSource(); + var cts2 = new CancellationTokenSource(); + using var linkedCts = RemoteJSDataStream.ValueLinkedCancellationTokenSource.Create(cts1.Token, cts2.Token); - // Assert - Assert.True(cts.Token.CanBeCanceled); - Assert.False(cts.IsCancellationRequested); + Assert.True(linkedCts.HasLinkedCancellationTokenSource); + Assert.False(linkedCts.Token.IsCancellationRequested); + + if (shouldCancelToken1) + { + cts1.Cancel(); + } + if (shouldCancelToken2) + { + cts2.Cancel(); + } + + Assert.True(linkedCts.Token.IsCancellationRequested); } private static async Task CreateRemoteJSDataStreamAsync(TestRemoteJSRuntime jsRuntime = null)