diff --git a/src/StackExchange.Redis/ChannelMessageQueue.cs b/src/StackExchange.Redis/ChannelMessageQueue.cs
index 1a5b62baa..bd596fdaa 100644
--- a/src/StackExchange.Redis/ChannelMessageQueue.cs
+++ b/src/StackExchange.Redis/ChannelMessageQueue.cs
@@ -15,7 +15,7 @@ public readonly struct ChannelMessage
///
/// See Object.ToString
///
- public override string ToString() => ((string)Channel) + ":" + ((string)Message);
+ public override string ToString() => Channel + ":" + Message;
///
/// See Object.GetHashCode
@@ -28,7 +28,14 @@ public readonly struct ChannelMessage
/// The to compare.
public override bool Equals(object obj) => obj is ChannelMessage cm
&& cm.Channel == Channel && cm.Message == Message;
- internal ChannelMessage(ChannelMessageQueue queue, RedisChannel channel, RedisValue value)
+
+ ///
+ /// Create a new representing a message written to a
+ ///
+ /// The associated with this message.
+ /// A identifying the channel from which the message was received.
+ /// A representing the value of the message.
+ public ChannelMessage(ChannelMessageQueue queue, RedisChannel channel, RedisValue value)
{
_queue = queue;
Channel = channel;
@@ -56,50 +63,44 @@ internal ChannelMessage(ChannelMessageQueue queue, RedisChannel channel, RedisVa
/// To create a ChannelMessageQueue, use ISubscriber.Subscribe[Async](RedisKey)
public sealed class ChannelMessageQueue
{
- private readonly Channel _queue;
+ private readonly ChannelReader _queue;
+ private readonly Action _onUnsubscribe;
+ private readonly Func _onUnsubscribeAsync;
+ private readonly Action _onInternalError;
+
///
/// The Channel that was subscribed for this queue
///
public RedisChannel Channel { get; }
- private RedisSubscriber _parent;
///
/// See Object.ToString
///
- public override string ToString() => (string)Channel;
+ public override string ToString() => Channel;
///
/// An awaitable task the indicates completion of the queue (including drain of data)
///
- public Task Completion => _queue.Reader.Completion;
+ public Task Completion => _queue.Completion;
- internal ChannelMessageQueue(RedisChannel redisChannel, RedisSubscriber parent)
+ ///
+ /// Constructs a from a representing
+ /// incoming Redis values on the channel.
+ ///
+ /// The name of the channel this subscription is listening on.
+ /// A channel reader representing the incoming values.
+ /// A delegate to call when is called.
+ /// A delegate to call when is called.
+ /// REVIEW: Need more context here
+ public ChannelMessageQueue(RedisChannel channel, ChannelReader incomingValues, Action onUnsubscribe, Func onUnsubscribeAsync, Action onInternalError)
{
- Channel = redisChannel;
- _parent = parent;
- _queue = System.Threading.Channels.Channel.CreateUnbounded(s_ChannelOptions);
- }
+ Channel = channel;
+ _queue = incomingValues;
- private static readonly UnboundedChannelOptions s_ChannelOptions = new UnboundedChannelOptions
- {
- SingleWriter = true,
- SingleReader = false,
- AllowSynchronousContinuations = false,
- };
- internal void Subscribe(CommandFlags flags) => _parent.Subscribe(Channel, HandleMessage, flags);
- internal Task SubscribeAsync(CommandFlags flags) => _parent.SubscribeAsync(Channel, HandleMessage, flags);
-
- private void HandleMessage(RedisChannel channel, RedisValue value)
- {
- var writer = _queue.Writer;
- if (channel.IsNull && value.IsNull) // see ForSyncShutdown
- {
- writer.TryComplete();
- }
- else
- {
- writer.TryWrite(new ChannelMessage(this, channel, value));
- }
+ // REVIEW: This part is kind of hacky...
+ _onUnsubscribe = onUnsubscribe;
+ _onUnsubscribeAsync = onUnsubscribeAsync;
+ _onInternalError = onInternalError;
}
///
@@ -107,13 +108,14 @@ private void HandleMessage(RedisChannel channel, RedisValue value)
///
/// The to use.
public ValueTask ReadAsync(CancellationToken cancellationToken = default)
- => _queue.Reader.ReadAsync(cancellationToken);
+ => _queue.ReadAsync(cancellationToken);
///
/// Attempt to synchronously consume a message from the channel.
///
/// The read from the Channel.
- public bool TryRead(out ChannelMessage item) => _queue.Reader.TryRead(out item);
+ public bool TryRead(out ChannelMessage item)
+ => _queue.TryRead(out item);
///
/// Attempt to query the backlog length of the queue.
@@ -124,11 +126,16 @@ public bool TryGetCount(out int count)
// get this using the reflection
try
{
- var prop = _queue.GetType().GetProperty("ItemsCountForDebugger", BindingFlags.Instance | BindingFlags.NonPublic);
- if (prop != null)
+ var parentField = _queue.GetType().GetField("_parent", BindingFlags.Instance | BindingFlags.NonPublic);
+ if (parentField != null)
{
- count = (int)prop.GetValue(_queue);
- return true;
+ var parent = parentField.GetValue(_queue);
+ var prop = parent.GetType().GetProperty("ItemsCountForDebugger", BindingFlags.Instance | BindingFlags.NonPublic);
+ if (prop != null)
+ {
+ count = (int)prop.GetValue(parent);
+ return true;
+ }
}
}
catch { }
@@ -139,9 +146,15 @@ public bool TryGetCount(out int count)
private Delegate _onMessageHandler;
private void AssertOnMessage(Delegate handler)
{
- if (handler == null) throw new ArgumentNullException(nameof(handler));
+ if (handler == null)
+ {
+ throw new ArgumentNullException(nameof(handler));
+ }
+
if (Interlocked.CompareExchange(ref _onMessageHandler, handler, null) != null)
+ {
throw new InvalidOperationException("Only a single " + nameof(OnMessage) + " is allowed");
+ }
}
///
@@ -162,17 +175,29 @@ private async Task OnMessageSyncImpl()
while (!Completion.IsCompleted)
{
ChannelMessage next;
- try { if (!TryRead(out next)) next = await ReadAsync().ConfigureAwait(false); }
- catch (ChannelClosedException) { break; } // expected
- catch (Exception ex)
+
+ // Keep trying to read values
+ while (!TryRead(out next))
{
- _parent.multiplexer?.OnInternalError(ex);
- break;
+ // If we fail, wait for an item to appear
+ if (!await _queue.WaitToReadAsync())
+ {
+ // Channel is closed
+ break;
+ }
+
+ // There should be an item available now, but another reader might grab it,
+ // so we keep TryReading in the loop.
}
try { handler(next); }
catch { } // matches MessageCompletable
}
+
+ if (Completion.IsFaulted)
+ {
+ _onInternalError(Completion.Exception.InnerException);
+ }
}
///
@@ -193,68 +218,63 @@ private async Task OnMessageAsyncImpl()
while (!Completion.IsCompleted)
{
ChannelMessage next;
- try { if (!TryRead(out next)) next = await ReadAsync().ConfigureAwait(false); }
- catch (ChannelClosedException) { break; } // expected
- catch (Exception ex)
+
+ // Keep trying to read values
+ while (!TryRead(out next))
{
- _parent.multiplexer?.OnInternalError(ex);
- break;
+ // If we fail, wait for an item to appear
+ if (!await _queue.WaitToReadAsync())
+ {
+ // Channel is closed
+ break;
+ }
+
+ // There should be an item available now, but another reader might grab it,
+ // so we keep TryReading in the loop.
}
try
{
var task = handler(next);
- if (task != null && task.Status != TaskStatus.RanToCompletion) await task.ConfigureAwait(false);
+ if (task != null && task.Status != TaskStatus.RanToCompletion)
+ {
+ await task.ConfigureAwait(false);
+ }
}
catch { } // matches MessageCompletable
}
- }
- internal void UnsubscribeImpl(Exception error = null, CommandFlags flags = CommandFlags.None)
- {
- var parent = _parent;
- _parent = null;
- if (parent != null)
+ if (Completion.IsFaulted)
{
- parent.UnsubscribeAsync(Channel, HandleMessage, flags);
+ _onInternalError(Completion.Exception.InnerException);
}
- _queue.Writer.TryComplete(error);
- }
-
- internal async Task UnsubscribeAsyncImpl(Exception error = null, CommandFlags flags = CommandFlags.None)
- {
- var parent = _parent;
- _parent = null;
- if (parent != null)
- {
- await parent.UnsubscribeAsync(Channel, HandleMessage, flags).ConfigureAwait(false);
- }
- _queue.Writer.TryComplete(error);
}
internal static bool IsOneOf(Action handler)
{
- try
- {
- return handler?.Target is ChannelMessageQueue
- && handler.Method.Name == nameof(HandleMessage);
- }
- catch
- {
- return false;
- }
+ // REVIEW: Need more context here to properly replace this.
+ throw new NotImplementedException();
+ //try
+ //{
+ // return handler?.Target is ChannelMessageQueue
+ // && handler.Method.Name == nameof(HandleMessage);
+ //}
+ //catch
+ //{
+ // return false;
+ //}
}
///
/// Stop receiving messages on this channel.
///
/// The flags to use when unsubscribing.
- public void Unsubscribe(CommandFlags flags = CommandFlags.None) => UnsubscribeImpl(null, flags);
+ public void Unsubscribe(CommandFlags flags = CommandFlags.None) => _onUnsubscribe(flags);
///
/// Stop receiving messages on this channel.
///
/// The flags to use when unsubscribing.
- public Task UnsubscribeAsync(CommandFlags flags = CommandFlags.None) => UnsubscribeAsyncImpl(null, flags);
+ public Task UnsubscribeAsync(CommandFlags flags = CommandFlags.None) => _onUnsubscribeAsync(flags);
}
}
diff --git a/src/StackExchange.Redis/RedisSubscriber.cs b/src/StackExchange.Redis/RedisSubscriber.cs
index 25d30969d..b5fa562eb 100644
--- a/src/StackExchange.Redis/RedisSubscriber.cs
+++ b/src/StackExchange.Redis/RedisSubscriber.cs
@@ -4,6 +4,7 @@
using System.Net;
using System.Runtime.CompilerServices;
using System.Threading;
+using System.Threading.Channels;
using System.Threading.Tasks;
namespace StackExchange.Redis
@@ -297,6 +298,13 @@ internal void OnTransactionLog(string message)
internal sealed class RedisSubscriber : RedisBase, ISubscriber
{
+ private static readonly UnboundedChannelOptions s_ChannelOptions = new UnboundedChannelOptions
+ {
+ SingleWriter = true,
+ SingleReader = false,
+ AllowSynchronousContinuations = false,
+ };
+
internal RedisSubscriber(ConnectionMultiplexer multiplexer, object asyncState) : base(multiplexer, asyncState)
{
}
@@ -376,22 +384,23 @@ public void Subscribe(RedisChannel channel, Action han
public ChannelMessageQueue Subscribe(RedisChannel channel, CommandFlags flags = CommandFlags.None)
{
- var c = new ChannelMessageQueue(channel, this);
- c.Subscribe(flags);
- return c;
+ var (queue, handler) = CreateMessageQueue(channel);
+ Subscribe(channel, handler, flags);
+ return queue;
}
public Task SubscribeAsync(RedisChannel channel, Action handler, CommandFlags flags = CommandFlags.None)
{
if (channel.IsNullOrEmpty) throw new ArgumentNullException(nameof(channel));
+
return multiplexer.AddSubscription(channel, handler, flags, asyncState);
}
public async Task SubscribeAsync(RedisChannel channel, CommandFlags flags = CommandFlags.None)
{
- var c = new ChannelMessageQueue(channel, this);
- await c.SubscribeAsync(flags).ForAwait();
- return c;
+ var (queue, handler) = CreateMessageQueue(channel);
+ await SubscribeAsync(channel, handler, flags).ForAwait();
+ return queue;
}
public EndPoint SubscribedEndpoint(RedisChannel channel)
@@ -420,7 +429,39 @@ public Task UnsubscribeAllAsync(CommandFlags flags = CommandFlags.None)
public Task UnsubscribeAsync(RedisChannel channel, Action handler = null, CommandFlags flags = CommandFlags.None)
{
if (channel.IsNullOrEmpty) throw new ArgumentNullException(nameof(channel));
+
return multiplexer.RemoveSubscription(channel, handler, flags, asyncState);
}
+
+ internal (ChannelMessageQueue queue, Action handler) CreateMessageQueue(RedisChannel channel)
+ {
+ var queue = Channel.CreateUnbounded(s_ChannelOptions);
+
+ // We need to create the variable before constructing the queue so that the Handler function below can capture it.
+ // It'll be fully assigned before it's used since the delegate is invoked by the Unsubscribe instance method,
+ // which isn't available until the value is fully initialized.
+ ChannelMessageQueue messageQueue = null;
+ messageQueue = new ChannelMessageQueue(
+ channel,
+ queue.Reader,
+ (f) => Unsubscribe(channel, Handler, f),
+ (f) => UnsubscribeAsync(channel, Handler, f),
+ (ex) => multiplexer?.OnInternalError(ex));
+
+ void Handler(RedisChannel c, RedisValue v)
+ {
+ if (c.IsNull && v.IsNull)
+ {
+ queue.Writer.TryComplete();
+ }
+ else
+ {
+ var wrote = queue.Writer.TryWrite(new ChannelMessage(messageQueue, c, v));
+ Debug.Assert(wrote, "Queue should be unbounded!");
+ }
+ }
+
+ return (messageQueue, Handler);
+ }
}
}
diff --git a/tests/StackExchange.Redis.Tests/ChannelMessageQueueTests.cs b/tests/StackExchange.Redis.Tests/ChannelMessageQueueTests.cs
new file mode 100644
index 000000000..d2060cb5c
--- /dev/null
+++ b/tests/StackExchange.Redis.Tests/ChannelMessageQueueTests.cs
@@ -0,0 +1,220 @@
+using System;
+using System.Threading.Channels;
+using System.Threading.Tasks;
+using Xunit;
+
+namespace StackExchange.Redis.Tests
+{
+ public class ChannelMessageQueueTests
+ {
+ [Fact]
+ public void ItemsYieldedToChannelCanBeRead()
+ {
+ var channel = Channel.CreateUnbounded();
+ var queue = new ChannelMessageQueue(
+ "TestChannel",
+ channel.Reader,
+ (f) => throw new NotImplementedException(),
+ (f) => throw new NotImplementedException(),
+ ex => throw new NotImplementedException());
+
+ Assert.True(channel.Writer.TryWrite(new ChannelMessage(queue, "TestChannel", "Test")));
+ Assert.True(queue.TryRead(out var message));
+ Assert.Equal("TestChannel", message.Channel);
+ Assert.Equal("Test", message.Message);
+ Assert.Equal("TestChannel", message.SubscriptionChannel);
+ }
+
+ [Fact]
+ public void ActualChannelIsProvidedInChannelMessageChanelProperty()
+ {
+ var channel = Channel.CreateUnbounded();
+ var queue = new ChannelMessageQueue(
+ "TestChannel.*",
+ channel.Reader,
+ (f) => throw new NotImplementedException(),
+ (f) => throw new NotImplementedException(),
+ ex => throw new NotImplementedException());
+
+ Assert.True(channel.Writer.TryWrite(new ChannelMessage(queue, "TestChannel.A", "Test")));
+ Assert.True(queue.TryRead(out var message));
+ Assert.Equal("TestChannel.A", message.Channel);
+ Assert.Equal("Test", message.Message);
+ Assert.Equal("TestChannel.*", message.SubscriptionChannel);
+ }
+
+ [Fact]
+ public void ReadAsyncYieldsItemWhenOneIsAvailable()
+ {
+ var channel = Channel.CreateUnbounded();
+ var queue = new ChannelMessageQueue(
+ "TestChannel",
+ channel.Reader,
+ (f) => throw new NotImplementedException(),
+ (f) => throw new NotImplementedException(),
+ ex => throw new NotImplementedException());
+
+ var readTask = queue.ReadAsync();
+
+ Assert.True(channel.Writer.TryWrite(new ChannelMessage(queue, "TestChannel", "Test")));
+ Assert.True(readTask.IsCompleted);
+ var message = readTask.GetAwaiter().GetResult();
+
+ Assert.Equal("TestChannel", message.Channel);
+ Assert.Equal("Test", message.Message);
+ Assert.Equal("TestChannel", message.SubscriptionChannel);
+ }
+
+ [Fact]
+ public void UnsubscribeSignalsProvidedDelegate()
+ {
+ var channel = Channel.CreateUnbounded();
+ CommandFlags? unsubscribeFlags = null;
+ var queue = new ChannelMessageQueue(
+ "TestChannel",
+ channel.Reader,
+ (f) => unsubscribeFlags = f,
+ (f) => throw new NotImplementedException(),
+ ex => throw new NotImplementedException());
+
+ queue.Unsubscribe(CommandFlags.FireAndForget);
+
+ Assert.Equal(CommandFlags.FireAndForget, unsubscribeFlags);
+ }
+
+ [Fact]
+ public void UnsubscribeAsyncSignalsProvidedDelegate()
+ {
+ var channel = Channel.CreateUnbounded();
+ CommandFlags? unsubscribeFlags = null;
+ var queue = new ChannelMessageQueue(
+ "TestChannel",
+ channel.Reader,
+ (f) => throw new NotImplementedException(),
+ (f) =>
+ {
+ unsubscribeFlags = f;
+ return Task.CompletedTask;
+ },
+ ex => throw new NotImplementedException());
+
+ Assert.True(queue.UnsubscribeAsync(CommandFlags.FireAndForget).IsCompleted);
+ Assert.Equal(CommandFlags.FireAndForget, unsubscribeFlags);
+ }
+
+ [Fact]
+ public async Task OnMessageCreatesMessageLoop()
+ {
+ var channel = Channel.CreateUnbounded();
+ var received = Channel.CreateUnbounded();
+ var queue = new ChannelMessageQueue(
+ "TestChannel",
+ channel.Reader,
+ (f) => throw new NotImplementedException(),
+ (f) => throw new NotImplementedException(),
+ ex => throw new NotImplementedException());
+
+ // Start the message loop
+ // NOTE: The loop runs on the thread pool
+ queue.OnMessage(m =>
+ {
+ Assert.True(received.Writer.TryWrite(m));
+ });
+
+ // Write an item and verify that it comes through the loop
+ Assert.True(channel.Writer.TryWrite(new ChannelMessage(queue, "TestChannel", "Test")));
+
+ var message = await received.Reader.ReadAsync();
+ Assert.Equal("TestChannel", message.Channel);
+ Assert.Equal("Test", message.Message);
+ Assert.Equal("TestChannel", message.SubscriptionChannel);
+
+ // Shut down the loop.
+ channel.Writer.TryComplete();
+ }
+
+ [Fact]
+ public async Task OnMessageAsyncCreatesMessageLoop()
+ {
+ var channel = Channel.CreateUnbounded();
+ var received = Channel.CreateUnbounded();
+ var queue = new ChannelMessageQueue(
+ "TestChannel",
+ channel.Reader,
+ (f) => throw new NotImplementedException(),
+ (f) => throw new NotImplementedException(),
+ ex => throw new NotImplementedException());
+
+ // Start the message loop
+ // NOTE: The loop runs on the thread pool
+ queue.OnMessage(m =>
+ {
+ Assert.True(received.Writer.TryWrite(m));
+ return Task.CompletedTask;
+ });
+
+ // Write an item and verify that it comes through the loop
+ Assert.True(channel.Writer.TryWrite(new ChannelMessage(queue, "TestChannel", "Test")));
+
+ var message = await received.Reader.ReadAsync();
+ Assert.Equal("TestChannel", message.Channel);
+ Assert.Equal("Test", message.Message);
+ Assert.Equal("TestChannel", message.SubscriptionChannel);
+
+ // Shut down the loop.
+ channel.Writer.TryComplete();
+ }
+
+ [Fact]
+ public async Task CompletionExceptionGoesToInternalErrorHandlerWhenUsingOnMessage()
+ {
+ var channel = Channel.CreateUnbounded();
+ var received = Channel.CreateUnbounded();
+ var errorOccurred = new TaskCompletionSource();
+
+ var queue = new ChannelMessageQueue(
+ "TestChannel",
+ channel.Reader,
+ (f) => throw new NotImplementedException(),
+ (f) => throw new NotImplementedException(),
+ ex => errorOccurred.TrySetResult(ex));
+
+ // Start the message loop
+ // NOTE: The loop runs on the thread pool
+ queue.OnMessage(m =>
+ {
+ throw new NotImplementedException();
+ });
+
+ // Complete with an error, triggering the exception
+ channel.Writer.TryComplete(new Exception("BARF!"));
+ Assert.Equal("BARF!", (await errorOccurred.Task).Message);
+ }
+
+ [Fact]
+ public async Task CompletionExceptionGoesToInternalErrorHandlerWhenUsingOnMessageAsync()
+ {
+ var channel = Channel.CreateUnbounded(new UnboundedChannelOptions());
+ var received = Channel.CreateUnbounded();
+ var errorOccurred = new TaskCompletionSource();
+
+ var queue = new ChannelMessageQueue(
+ "TestChannel",
+ channel.Reader,
+ (f) => throw new NotImplementedException(),
+ (f) => throw new NotImplementedException(),
+ ex => errorOccurred.TrySetResult(ex));
+
+ // Start the message loop
+ // NOTE: The loop runs on the thread pool
+ queue.OnMessage(m =>
+ {
+ return Task.FromException(new NotImplementedException());
+ });
+
+ // Complete with an error, triggering the exception
+ channel.Writer.TryComplete(new Exception("BARF!"));
+ Assert.Equal("BARF!", (await errorOccurred.Task).Message);
+ }
+ }
+}