diff --git a/src/StackExchange.Redis/ChannelMessageQueue.cs b/src/StackExchange.Redis/ChannelMessageQueue.cs index 1a5b62baa..bd596fdaa 100644 --- a/src/StackExchange.Redis/ChannelMessageQueue.cs +++ b/src/StackExchange.Redis/ChannelMessageQueue.cs @@ -15,7 +15,7 @@ public readonly struct ChannelMessage /// /// See Object.ToString /// - public override string ToString() => ((string)Channel) + ":" + ((string)Message); + public override string ToString() => Channel + ":" + Message; /// /// See Object.GetHashCode @@ -28,7 +28,14 @@ public readonly struct ChannelMessage /// The to compare. public override bool Equals(object obj) => obj is ChannelMessage cm && cm.Channel == Channel && cm.Message == Message; - internal ChannelMessage(ChannelMessageQueue queue, RedisChannel channel, RedisValue value) + + /// + /// Create a new representing a message written to a + /// + /// The associated with this message. + /// A identifying the channel from which the message was received. + /// A representing the value of the message. + public ChannelMessage(ChannelMessageQueue queue, RedisChannel channel, RedisValue value) { _queue = queue; Channel = channel; @@ -56,50 +63,44 @@ internal ChannelMessage(ChannelMessageQueue queue, RedisChannel channel, RedisVa /// To create a ChannelMessageQueue, use ISubscriber.Subscribe[Async](RedisKey) public sealed class ChannelMessageQueue { - private readonly Channel _queue; + private readonly ChannelReader _queue; + private readonly Action _onUnsubscribe; + private readonly Func _onUnsubscribeAsync; + private readonly Action _onInternalError; + /// /// The Channel that was subscribed for this queue /// public RedisChannel Channel { get; } - private RedisSubscriber _parent; /// /// See Object.ToString /// - public override string ToString() => (string)Channel; + public override string ToString() => Channel; /// /// An awaitable task the indicates completion of the queue (including drain of data) /// - public Task Completion => _queue.Reader.Completion; + public Task Completion => _queue.Completion; - internal ChannelMessageQueue(RedisChannel redisChannel, RedisSubscriber parent) + /// + /// Constructs a from a representing + /// incoming Redis values on the channel. + /// + /// The name of the channel this subscription is listening on. + /// A channel reader representing the incoming values. + /// A delegate to call when is called. + /// A delegate to call when is called. + /// REVIEW: Need more context here + public ChannelMessageQueue(RedisChannel channel, ChannelReader incomingValues, Action onUnsubscribe, Func onUnsubscribeAsync, Action onInternalError) { - Channel = redisChannel; - _parent = parent; - _queue = System.Threading.Channels.Channel.CreateUnbounded(s_ChannelOptions); - } + Channel = channel; + _queue = incomingValues; - private static readonly UnboundedChannelOptions s_ChannelOptions = new UnboundedChannelOptions - { - SingleWriter = true, - SingleReader = false, - AllowSynchronousContinuations = false, - }; - internal void Subscribe(CommandFlags flags) => _parent.Subscribe(Channel, HandleMessage, flags); - internal Task SubscribeAsync(CommandFlags flags) => _parent.SubscribeAsync(Channel, HandleMessage, flags); - - private void HandleMessage(RedisChannel channel, RedisValue value) - { - var writer = _queue.Writer; - if (channel.IsNull && value.IsNull) // see ForSyncShutdown - { - writer.TryComplete(); - } - else - { - writer.TryWrite(new ChannelMessage(this, channel, value)); - } + // REVIEW: This part is kind of hacky... + _onUnsubscribe = onUnsubscribe; + _onUnsubscribeAsync = onUnsubscribeAsync; + _onInternalError = onInternalError; } /// @@ -107,13 +108,14 @@ private void HandleMessage(RedisChannel channel, RedisValue value) /// /// The to use. public ValueTask ReadAsync(CancellationToken cancellationToken = default) - => _queue.Reader.ReadAsync(cancellationToken); + => _queue.ReadAsync(cancellationToken); /// /// Attempt to synchronously consume a message from the channel. /// /// The read from the Channel. - public bool TryRead(out ChannelMessage item) => _queue.Reader.TryRead(out item); + public bool TryRead(out ChannelMessage item) + => _queue.TryRead(out item); /// /// Attempt to query the backlog length of the queue. @@ -124,11 +126,16 @@ public bool TryGetCount(out int count) // get this using the reflection try { - var prop = _queue.GetType().GetProperty("ItemsCountForDebugger", BindingFlags.Instance | BindingFlags.NonPublic); - if (prop != null) + var parentField = _queue.GetType().GetField("_parent", BindingFlags.Instance | BindingFlags.NonPublic); + if (parentField != null) { - count = (int)prop.GetValue(_queue); - return true; + var parent = parentField.GetValue(_queue); + var prop = parent.GetType().GetProperty("ItemsCountForDebugger", BindingFlags.Instance | BindingFlags.NonPublic); + if (prop != null) + { + count = (int)prop.GetValue(parent); + return true; + } } } catch { } @@ -139,9 +146,15 @@ public bool TryGetCount(out int count) private Delegate _onMessageHandler; private void AssertOnMessage(Delegate handler) { - if (handler == null) throw new ArgumentNullException(nameof(handler)); + if (handler == null) + { + throw new ArgumentNullException(nameof(handler)); + } + if (Interlocked.CompareExchange(ref _onMessageHandler, handler, null) != null) + { throw new InvalidOperationException("Only a single " + nameof(OnMessage) + " is allowed"); + } } /// @@ -162,17 +175,29 @@ private async Task OnMessageSyncImpl() while (!Completion.IsCompleted) { ChannelMessage next; - try { if (!TryRead(out next)) next = await ReadAsync().ConfigureAwait(false); } - catch (ChannelClosedException) { break; } // expected - catch (Exception ex) + + // Keep trying to read values + while (!TryRead(out next)) { - _parent.multiplexer?.OnInternalError(ex); - break; + // If we fail, wait for an item to appear + if (!await _queue.WaitToReadAsync()) + { + // Channel is closed + break; + } + + // There should be an item available now, but another reader might grab it, + // so we keep TryReading in the loop. } try { handler(next); } catch { } // matches MessageCompletable } + + if (Completion.IsFaulted) + { + _onInternalError(Completion.Exception.InnerException); + } } /// @@ -193,68 +218,63 @@ private async Task OnMessageAsyncImpl() while (!Completion.IsCompleted) { ChannelMessage next; - try { if (!TryRead(out next)) next = await ReadAsync().ConfigureAwait(false); } - catch (ChannelClosedException) { break; } // expected - catch (Exception ex) + + // Keep trying to read values + while (!TryRead(out next)) { - _parent.multiplexer?.OnInternalError(ex); - break; + // If we fail, wait for an item to appear + if (!await _queue.WaitToReadAsync()) + { + // Channel is closed + break; + } + + // There should be an item available now, but another reader might grab it, + // so we keep TryReading in the loop. } try { var task = handler(next); - if (task != null && task.Status != TaskStatus.RanToCompletion) await task.ConfigureAwait(false); + if (task != null && task.Status != TaskStatus.RanToCompletion) + { + await task.ConfigureAwait(false); + } } catch { } // matches MessageCompletable } - } - internal void UnsubscribeImpl(Exception error = null, CommandFlags flags = CommandFlags.None) - { - var parent = _parent; - _parent = null; - if (parent != null) + if (Completion.IsFaulted) { - parent.UnsubscribeAsync(Channel, HandleMessage, flags); + _onInternalError(Completion.Exception.InnerException); } - _queue.Writer.TryComplete(error); - } - - internal async Task UnsubscribeAsyncImpl(Exception error = null, CommandFlags flags = CommandFlags.None) - { - var parent = _parent; - _parent = null; - if (parent != null) - { - await parent.UnsubscribeAsync(Channel, HandleMessage, flags).ConfigureAwait(false); - } - _queue.Writer.TryComplete(error); } internal static bool IsOneOf(Action handler) { - try - { - return handler?.Target is ChannelMessageQueue - && handler.Method.Name == nameof(HandleMessage); - } - catch - { - return false; - } + // REVIEW: Need more context here to properly replace this. + throw new NotImplementedException(); + //try + //{ + // return handler?.Target is ChannelMessageQueue + // && handler.Method.Name == nameof(HandleMessage); + //} + //catch + //{ + // return false; + //} } /// /// Stop receiving messages on this channel. /// /// The flags to use when unsubscribing. - public void Unsubscribe(CommandFlags flags = CommandFlags.None) => UnsubscribeImpl(null, flags); + public void Unsubscribe(CommandFlags flags = CommandFlags.None) => _onUnsubscribe(flags); /// /// Stop receiving messages on this channel. /// /// The flags to use when unsubscribing. - public Task UnsubscribeAsync(CommandFlags flags = CommandFlags.None) => UnsubscribeAsyncImpl(null, flags); + public Task UnsubscribeAsync(CommandFlags flags = CommandFlags.None) => _onUnsubscribeAsync(flags); } } diff --git a/src/StackExchange.Redis/RedisSubscriber.cs b/src/StackExchange.Redis/RedisSubscriber.cs index 25d30969d..b5fa562eb 100644 --- a/src/StackExchange.Redis/RedisSubscriber.cs +++ b/src/StackExchange.Redis/RedisSubscriber.cs @@ -4,6 +4,7 @@ using System.Net; using System.Runtime.CompilerServices; using System.Threading; +using System.Threading.Channels; using System.Threading.Tasks; namespace StackExchange.Redis @@ -297,6 +298,13 @@ internal void OnTransactionLog(string message) internal sealed class RedisSubscriber : RedisBase, ISubscriber { + private static readonly UnboundedChannelOptions s_ChannelOptions = new UnboundedChannelOptions + { + SingleWriter = true, + SingleReader = false, + AllowSynchronousContinuations = false, + }; + internal RedisSubscriber(ConnectionMultiplexer multiplexer, object asyncState) : base(multiplexer, asyncState) { } @@ -376,22 +384,23 @@ public void Subscribe(RedisChannel channel, Action han public ChannelMessageQueue Subscribe(RedisChannel channel, CommandFlags flags = CommandFlags.None) { - var c = new ChannelMessageQueue(channel, this); - c.Subscribe(flags); - return c; + var (queue, handler) = CreateMessageQueue(channel); + Subscribe(channel, handler, flags); + return queue; } public Task SubscribeAsync(RedisChannel channel, Action handler, CommandFlags flags = CommandFlags.None) { if (channel.IsNullOrEmpty) throw new ArgumentNullException(nameof(channel)); + return multiplexer.AddSubscription(channel, handler, flags, asyncState); } public async Task SubscribeAsync(RedisChannel channel, CommandFlags flags = CommandFlags.None) { - var c = new ChannelMessageQueue(channel, this); - await c.SubscribeAsync(flags).ForAwait(); - return c; + var (queue, handler) = CreateMessageQueue(channel); + await SubscribeAsync(channel, handler, flags).ForAwait(); + return queue; } public EndPoint SubscribedEndpoint(RedisChannel channel) @@ -420,7 +429,39 @@ public Task UnsubscribeAllAsync(CommandFlags flags = CommandFlags.None) public Task UnsubscribeAsync(RedisChannel channel, Action handler = null, CommandFlags flags = CommandFlags.None) { if (channel.IsNullOrEmpty) throw new ArgumentNullException(nameof(channel)); + return multiplexer.RemoveSubscription(channel, handler, flags, asyncState); } + + internal (ChannelMessageQueue queue, Action handler) CreateMessageQueue(RedisChannel channel) + { + var queue = Channel.CreateUnbounded(s_ChannelOptions); + + // We need to create the variable before constructing the queue so that the Handler function below can capture it. + // It'll be fully assigned before it's used since the delegate is invoked by the Unsubscribe instance method, + // which isn't available until the value is fully initialized. + ChannelMessageQueue messageQueue = null; + messageQueue = new ChannelMessageQueue( + channel, + queue.Reader, + (f) => Unsubscribe(channel, Handler, f), + (f) => UnsubscribeAsync(channel, Handler, f), + (ex) => multiplexer?.OnInternalError(ex)); + + void Handler(RedisChannel c, RedisValue v) + { + if (c.IsNull && v.IsNull) + { + queue.Writer.TryComplete(); + } + else + { + var wrote = queue.Writer.TryWrite(new ChannelMessage(messageQueue, c, v)); + Debug.Assert(wrote, "Queue should be unbounded!"); + } + } + + return (messageQueue, Handler); + } } } diff --git a/tests/StackExchange.Redis.Tests/ChannelMessageQueueTests.cs b/tests/StackExchange.Redis.Tests/ChannelMessageQueueTests.cs new file mode 100644 index 000000000..d2060cb5c --- /dev/null +++ b/tests/StackExchange.Redis.Tests/ChannelMessageQueueTests.cs @@ -0,0 +1,220 @@ +using System; +using System.Threading.Channels; +using System.Threading.Tasks; +using Xunit; + +namespace StackExchange.Redis.Tests +{ + public class ChannelMessageQueueTests + { + [Fact] + public void ItemsYieldedToChannelCanBeRead() + { + var channel = Channel.CreateUnbounded(); + var queue = new ChannelMessageQueue( + "TestChannel", + channel.Reader, + (f) => throw new NotImplementedException(), + (f) => throw new NotImplementedException(), + ex => throw new NotImplementedException()); + + Assert.True(channel.Writer.TryWrite(new ChannelMessage(queue, "TestChannel", "Test"))); + Assert.True(queue.TryRead(out var message)); + Assert.Equal("TestChannel", message.Channel); + Assert.Equal("Test", message.Message); + Assert.Equal("TestChannel", message.SubscriptionChannel); + } + + [Fact] + public void ActualChannelIsProvidedInChannelMessageChanelProperty() + { + var channel = Channel.CreateUnbounded(); + var queue = new ChannelMessageQueue( + "TestChannel.*", + channel.Reader, + (f) => throw new NotImplementedException(), + (f) => throw new NotImplementedException(), + ex => throw new NotImplementedException()); + + Assert.True(channel.Writer.TryWrite(new ChannelMessage(queue, "TestChannel.A", "Test"))); + Assert.True(queue.TryRead(out var message)); + Assert.Equal("TestChannel.A", message.Channel); + Assert.Equal("Test", message.Message); + Assert.Equal("TestChannel.*", message.SubscriptionChannel); + } + + [Fact] + public void ReadAsyncYieldsItemWhenOneIsAvailable() + { + var channel = Channel.CreateUnbounded(); + var queue = new ChannelMessageQueue( + "TestChannel", + channel.Reader, + (f) => throw new NotImplementedException(), + (f) => throw new NotImplementedException(), + ex => throw new NotImplementedException()); + + var readTask = queue.ReadAsync(); + + Assert.True(channel.Writer.TryWrite(new ChannelMessage(queue, "TestChannel", "Test"))); + Assert.True(readTask.IsCompleted); + var message = readTask.GetAwaiter().GetResult(); + + Assert.Equal("TestChannel", message.Channel); + Assert.Equal("Test", message.Message); + Assert.Equal("TestChannel", message.SubscriptionChannel); + } + + [Fact] + public void UnsubscribeSignalsProvidedDelegate() + { + var channel = Channel.CreateUnbounded(); + CommandFlags? unsubscribeFlags = null; + var queue = new ChannelMessageQueue( + "TestChannel", + channel.Reader, + (f) => unsubscribeFlags = f, + (f) => throw new NotImplementedException(), + ex => throw new NotImplementedException()); + + queue.Unsubscribe(CommandFlags.FireAndForget); + + Assert.Equal(CommandFlags.FireAndForget, unsubscribeFlags); + } + + [Fact] + public void UnsubscribeAsyncSignalsProvidedDelegate() + { + var channel = Channel.CreateUnbounded(); + CommandFlags? unsubscribeFlags = null; + var queue = new ChannelMessageQueue( + "TestChannel", + channel.Reader, + (f) => throw new NotImplementedException(), + (f) => + { + unsubscribeFlags = f; + return Task.CompletedTask; + }, + ex => throw new NotImplementedException()); + + Assert.True(queue.UnsubscribeAsync(CommandFlags.FireAndForget).IsCompleted); + Assert.Equal(CommandFlags.FireAndForget, unsubscribeFlags); + } + + [Fact] + public async Task OnMessageCreatesMessageLoop() + { + var channel = Channel.CreateUnbounded(); + var received = Channel.CreateUnbounded(); + var queue = new ChannelMessageQueue( + "TestChannel", + channel.Reader, + (f) => throw new NotImplementedException(), + (f) => throw new NotImplementedException(), + ex => throw new NotImplementedException()); + + // Start the message loop + // NOTE: The loop runs on the thread pool + queue.OnMessage(m => + { + Assert.True(received.Writer.TryWrite(m)); + }); + + // Write an item and verify that it comes through the loop + Assert.True(channel.Writer.TryWrite(new ChannelMessage(queue, "TestChannel", "Test"))); + + var message = await received.Reader.ReadAsync(); + Assert.Equal("TestChannel", message.Channel); + Assert.Equal("Test", message.Message); + Assert.Equal("TestChannel", message.SubscriptionChannel); + + // Shut down the loop. + channel.Writer.TryComplete(); + } + + [Fact] + public async Task OnMessageAsyncCreatesMessageLoop() + { + var channel = Channel.CreateUnbounded(); + var received = Channel.CreateUnbounded(); + var queue = new ChannelMessageQueue( + "TestChannel", + channel.Reader, + (f) => throw new NotImplementedException(), + (f) => throw new NotImplementedException(), + ex => throw new NotImplementedException()); + + // Start the message loop + // NOTE: The loop runs on the thread pool + queue.OnMessage(m => + { + Assert.True(received.Writer.TryWrite(m)); + return Task.CompletedTask; + }); + + // Write an item and verify that it comes through the loop + Assert.True(channel.Writer.TryWrite(new ChannelMessage(queue, "TestChannel", "Test"))); + + var message = await received.Reader.ReadAsync(); + Assert.Equal("TestChannel", message.Channel); + Assert.Equal("Test", message.Message); + Assert.Equal("TestChannel", message.SubscriptionChannel); + + // Shut down the loop. + channel.Writer.TryComplete(); + } + + [Fact] + public async Task CompletionExceptionGoesToInternalErrorHandlerWhenUsingOnMessage() + { + var channel = Channel.CreateUnbounded(); + var received = Channel.CreateUnbounded(); + var errorOccurred = new TaskCompletionSource(); + + var queue = new ChannelMessageQueue( + "TestChannel", + channel.Reader, + (f) => throw new NotImplementedException(), + (f) => throw new NotImplementedException(), + ex => errorOccurred.TrySetResult(ex)); + + // Start the message loop + // NOTE: The loop runs on the thread pool + queue.OnMessage(m => + { + throw new NotImplementedException(); + }); + + // Complete with an error, triggering the exception + channel.Writer.TryComplete(new Exception("BARF!")); + Assert.Equal("BARF!", (await errorOccurred.Task).Message); + } + + [Fact] + public async Task CompletionExceptionGoesToInternalErrorHandlerWhenUsingOnMessageAsync() + { + var channel = Channel.CreateUnbounded(new UnboundedChannelOptions()); + var received = Channel.CreateUnbounded(); + var errorOccurred = new TaskCompletionSource(); + + var queue = new ChannelMessageQueue( + "TestChannel", + channel.Reader, + (f) => throw new NotImplementedException(), + (f) => throw new NotImplementedException(), + ex => errorOccurred.TrySetResult(ex)); + + // Start the message loop + // NOTE: The loop runs on the thread pool + queue.OnMessage(m => + { + return Task.FromException(new NotImplementedException()); + }); + + // Complete with an error, triggering the exception + channel.Writer.TryComplete(new Exception("BARF!")); + Assert.Equal("BARF!", (await errorOccurred.Task).Message); + } + } +}