Skip to content

WIP: #969 testable ChannelMessageQueue #1000

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
178 changes: 99 additions & 79 deletions src/StackExchange.Redis/ChannelMessageQueue.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ public readonly struct ChannelMessage
/// <summary>
/// See Object.ToString
/// </summary>
public override string ToString() => ((string)Channel) + ":" + ((string)Message);
public override string ToString() => Channel + ":" + Message;

/// <summary>
/// See Object.GetHashCode
Expand All @@ -28,7 +28,14 @@ public readonly struct ChannelMessage
/// <param name="obj">The <see cref="object"/> to compare.</param>
public override bool Equals(object obj) => obj is ChannelMessage cm
&& cm.Channel == Channel && cm.Message == Message;
internal ChannelMessage(ChannelMessageQueue queue, RedisChannel channel, RedisValue value)

/// <summary>
/// Create a new <see cref="ChannelMessage"/> representing a message written to a <see cref="ChannelMessageQueue"/>
/// </summary>
/// <param name="queue">The <see cref="ChannelMessageQueue"/> associated with this message.</param>
/// <param name="channel">A <see cref="RedisChannel"/> identifying the channel from which the message was received.</param>
/// <param name="value">A <see cref="RedisValue"/> representing the value of the message.</param>
public ChannelMessage(ChannelMessageQueue queue, RedisChannel channel, RedisValue value)
{
_queue = queue;
Channel = channel;
Expand Down Expand Up @@ -56,64 +63,59 @@ internal ChannelMessage(ChannelMessageQueue queue, RedisChannel channel, RedisVa
/// <remarks>To create a ChannelMessageQueue, use ISubscriber.Subscribe[Async](RedisKey)</remarks>
public sealed class ChannelMessageQueue
{
private readonly Channel<ChannelMessage> _queue;
private readonly ChannelReader<ChannelMessage> _queue;
private readonly Action<CommandFlags> _onUnsubscribe;
private readonly Func<CommandFlags, Task> _onUnsubscribeAsync;
private readonly Action<Exception> _onInternalError;

/// <summary>
/// The Channel that was subscribed for this queue
/// </summary>
public RedisChannel Channel { get; }
private RedisSubscriber _parent;

/// <summary>
/// See Object.ToString
/// </summary>
public override string ToString() => (string)Channel;
public override string ToString() => Channel;

/// <summary>
/// An awaitable task the indicates completion of the queue (including drain of data)
/// </summary>
public Task Completion => _queue.Reader.Completion;
public Task Completion => _queue.Completion;

internal ChannelMessageQueue(RedisChannel redisChannel, RedisSubscriber parent)
/// <summary>
/// Constructs a <see cref="ChannelMessageQueue" /> from a <see cref="System.Threading.Channels.ChannelReader{RedisValue}"/> representing
/// incoming Redis values on the channel.
/// </summary>
/// <param name="channel">The name of the channel this subscription is listening on.</param>
/// <param name="incomingValues">A channel reader representing the incoming values.</param>
/// <param name="onUnsubscribe">A delegate to call when <see cref="Unsubscribe(CommandFlags)"/> is called.</param>
/// <param name="onUnsubscribeAsync">A delegate to call when <see cref="UnsubscribeAsync(CommandFlags)"/> is called.</param>
/// <param name="onInternalError">REVIEW: Need more context here</param>
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not super clear on the purpose of this onInternalError callback (represents the previous call directly to ConnectionMultiplexer.OnInternalError). It was only called when a non-ChannelClosedException was thrown by ReadAsync in the message loop (which is highly unlikely).

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it was simply to do something in the case that an unexpected exception happened - agree "highly unlikely", but: at least in the "real" code, should do something; probably doesn't need to be mockable, if it presents a problem

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could have two constructors. One that takes a RedisSubscriber (i.e. the internal one, that can access OnInternalError directly and one that just takes a ChannelReader and an action to fire when unsubscribe is called. Since this is primarily (exclusively?) for testing, it's less important that the testable entry point be able to perform all the necessary functions.

public ChannelMessageQueue(RedisChannel channel, ChannelReader<ChannelMessage> incomingValues, Action<CommandFlags> onUnsubscribe, Func<CommandFlags, Task> onUnsubscribeAsync, Action<Exception> onInternalError)
{
Channel = redisChannel;
_parent = parent;
_queue = System.Threading.Channels.Channel.CreateUnbounded<ChannelMessage>(s_ChannelOptions);
}
Channel = channel;
_queue = incomingValues;

private static readonly UnboundedChannelOptions s_ChannelOptions = new UnboundedChannelOptions
{
SingleWriter = true,
SingleReader = false,
AllowSynchronousContinuations = false,
};
internal void Subscribe(CommandFlags flags) => _parent.Subscribe(Channel, HandleMessage, flags);
internal Task SubscribeAsync(CommandFlags flags) => _parent.SubscribeAsync(Channel, HandleMessage, flags);

private void HandleMessage(RedisChannel channel, RedisValue value)
{
var writer = _queue.Writer;
if (channel.IsNull && value.IsNull) // see ForSyncShutdown
{
writer.TryComplete();
}
else
{
writer.TryWrite(new ChannelMessage(this, channel, value));
}
// REVIEW: This part is kind of hacky...
_onUnsubscribe = onUnsubscribe;
_onUnsubscribeAsync = onUnsubscribeAsync;
_onInternalError = onInternalError;
}

/// <summary>
/// Consume a message from the channel.
/// </summary>
/// <param name="cancellationToken">The <see cref="CancellationToken"/> to use.</param>
public ValueTask<ChannelMessage> ReadAsync(CancellationToken cancellationToken = default)
=> _queue.Reader.ReadAsync(cancellationToken);
=> _queue.ReadAsync(cancellationToken);

/// <summary>
/// Attempt to synchronously consume a message from the channel.
/// </summary>
/// <param name="item">The <see cref="ChannelMessage"/> read from the Channel.</param>
public bool TryRead(out ChannelMessage item) => _queue.Reader.TryRead(out item);
public bool TryRead(out ChannelMessage item)
=> _queue.TryRead(out item);

/// <summary>
/// Attempt to query the backlog length of the queue.
Expand All @@ -124,11 +126,16 @@ public bool TryGetCount(out int count)
// get this using the reflection
try
{
var prop = _queue.GetType().GetProperty("ItemsCountForDebugger", BindingFlags.Instance | BindingFlags.NonPublic);
if (prop != null)
var parentField = _queue.GetType().GetField("_parent", BindingFlags.Instance | BindingFlags.NonPublic);
if (parentField != null)
{
count = (int)prop.GetValue(_queue);
return true;
var parent = parentField.GetValue(_queue);
var prop = parent.GetType().GetProperty("ItemsCountForDebugger", BindingFlags.Instance | BindingFlags.NonPublic);
if (prop != null)
{
count = (int)prop.GetValue(parent);
return true;
}
}
}
catch { }
Expand All @@ -139,9 +146,15 @@ public bool TryGetCount(out int count)
private Delegate _onMessageHandler;
private void AssertOnMessage(Delegate handler)
{
if (handler == null) throw new ArgumentNullException(nameof(handler));
if (handler == null)
{
throw new ArgumentNullException(nameof(handler));
}

if (Interlocked.CompareExchange(ref _onMessageHandler, handler, null) != null)
{
throw new InvalidOperationException("Only a single " + nameof(OnMessage) + " is allowed");
}
}

/// <summary>
Expand All @@ -162,17 +175,29 @@ private async Task OnMessageSyncImpl()
while (!Completion.IsCompleted)
{
ChannelMessage next;
try { if (!TryRead(out next)) next = await ReadAsync().ConfigureAwait(false); }
catch (ChannelClosedException) { break; } // expected
catch (Exception ex)

// Keep trying to read values
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the pattern we generally use to read from a channel and gracefully stop without an exception. If there's a reason for the older behavior, I'm happy to reset it.

while (!TryRead(out next))
{
_parent.multiplexer?.OnInternalError(ex);
break;
// If we fail, wait for an item to appear
if (!await _queue.WaitToReadAsync())
{
// Channel is closed
break;
}

// There should be an item available now, but another reader might grab it,
// so we keep TryReading in the loop.
}

try { handler(next); }
catch { } // matches MessageCompletable
}

if (Completion.IsFaulted)
{
_onInternalError(Completion.Exception.InnerException);
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is technically new behavior, it will trigger this callback if the channel is closed with an Exception, but we don't ever do that in the normal flow. A custom implementor that uses their own Channel in ChannelMessageQueue might though...

}
}

/// <summary>
Expand All @@ -193,68 +218,63 @@ private async Task OnMessageAsyncImpl()
while (!Completion.IsCompleted)
{
ChannelMessage next;
try { if (!TryRead(out next)) next = await ReadAsync().ConfigureAwait(false); }
catch (ChannelClosedException) { break; } // expected
catch (Exception ex)

// Keep trying to read values
while (!TryRead(out next))
{
_parent.multiplexer?.OnInternalError(ex);
break;
// If we fail, wait for an item to appear
if (!await _queue.WaitToReadAsync())
{
// Channel is closed
break;
}

// There should be an item available now, but another reader might grab it,
// so we keep TryReading in the loop.
}

try
{
var task = handler(next);
if (task != null && task.Status != TaskStatus.RanToCompletion) await task.ConfigureAwait(false);
if (task != null && task.Status != TaskStatus.RanToCompletion)
{
await task.ConfigureAwait(false);
}
}
catch { } // matches MessageCompletable
}
}

internal void UnsubscribeImpl(Exception error = null, CommandFlags flags = CommandFlags.None)
{
var parent = _parent;
_parent = null;
if (parent != null)
if (Completion.IsFaulted)
{
parent.UnsubscribeAsync(Channel, HandleMessage, flags);
_onInternalError(Completion.Exception.InnerException);
}
_queue.Writer.TryComplete(error);
}

internal async Task UnsubscribeAsyncImpl(Exception error = null, CommandFlags flags = CommandFlags.None)
{
var parent = _parent;
_parent = null;
if (parent != null)
{
await parent.UnsubscribeAsync(Channel, HandleMessage, flags).ConfigureAwait(false);
}
_queue.Writer.TryComplete(error);
}

internal static bool IsOneOf(Action<RedisChannel, RedisValue> handler)
{
try
{
return handler?.Target is ChannelMessageQueue
&& handler.Method.Name == nameof(HandleMessage);
}
catch
{
return false;
}
// REVIEW: Need more context here to properly replace this.
Copy link
Author

@analogrelay analogrelay Nov 12, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this can be easily restored by moving it to RedisSubscriber, since that's where the handler lives now. Just haven't done that yet.

throw new NotImplementedException();
//try
//{
// return handler?.Target is ChannelMessageQueue
// && handler.Method.Name == nameof(HandleMessage);
//}
//catch
//{
// return false;
//}
}

/// <summary>
/// Stop receiving messages on this channel.
/// </summary>
/// <param name="flags">The flags to use when unsubscribing.</param>
public void Unsubscribe(CommandFlags flags = CommandFlags.None) => UnsubscribeImpl(null, flags);
public void Unsubscribe(CommandFlags flags = CommandFlags.None) => _onUnsubscribe(flags);

/// <summary>
/// Stop receiving messages on this channel.
/// </summary>
/// <param name="flags">The flags to use when unsubscribing.</param>
public Task UnsubscribeAsync(CommandFlags flags = CommandFlags.None) => UnsubscribeAsyncImpl(null, flags);
public Task UnsubscribeAsync(CommandFlags flags = CommandFlags.None) => _onUnsubscribeAsync(flags);
}
}
53 changes: 47 additions & 6 deletions src/StackExchange.Redis/RedisSubscriber.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
using System.Net;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Channels;
using System.Threading.Tasks;

namespace StackExchange.Redis
Expand Down Expand Up @@ -297,6 +298,13 @@ internal void OnTransactionLog(string message)

internal sealed class RedisSubscriber : RedisBase, ISubscriber
{
private static readonly UnboundedChannelOptions s_ChannelOptions = new UnboundedChannelOptions
{
SingleWriter = true,
SingleReader = false,
AllowSynchronousContinuations = false,
};

internal RedisSubscriber(ConnectionMultiplexer multiplexer, object asyncState) : base(multiplexer, asyncState)
{
}
Expand Down Expand Up @@ -376,22 +384,23 @@ public void Subscribe(RedisChannel channel, Action<RedisChannel, RedisValue> han

public ChannelMessageQueue Subscribe(RedisChannel channel, CommandFlags flags = CommandFlags.None)
{
var c = new ChannelMessageQueue(channel, this);
c.Subscribe(flags);
return c;
var (queue, handler) = CreateMessageQueue(channel);
Subscribe(channel, handler, flags);
return queue;
}

public Task SubscribeAsync(RedisChannel channel, Action<RedisChannel, RedisValue> handler, CommandFlags flags = CommandFlags.None)
{
if (channel.IsNullOrEmpty) throw new ArgumentNullException(nameof(channel));

return multiplexer.AddSubscription(channel, handler, flags, asyncState);
}

public async Task<ChannelMessageQueue> SubscribeAsync(RedisChannel channel, CommandFlags flags = CommandFlags.None)
{
var c = new ChannelMessageQueue(channel, this);
await c.SubscribeAsync(flags).ForAwait();
return c;
var (queue, handler) = CreateMessageQueue(channel);
await SubscribeAsync(channel, handler, flags).ForAwait();
return queue;
}

public EndPoint SubscribedEndpoint(RedisChannel channel)
Expand Down Expand Up @@ -420,7 +429,39 @@ public Task UnsubscribeAllAsync(CommandFlags flags = CommandFlags.None)
public Task UnsubscribeAsync(RedisChannel channel, Action<RedisChannel, RedisValue> handler = null, CommandFlags flags = CommandFlags.None)
{
if (channel.IsNullOrEmpty) throw new ArgumentNullException(nameof(channel));

return multiplexer.RemoveSubscription(channel, handler, flags, asyncState);
}

internal (ChannelMessageQueue queue, Action<RedisChannel, RedisValue> handler) CreateMessageQueue(RedisChannel channel)
{
var queue = Channel.CreateUnbounded<ChannelMessage>(s_ChannelOptions);

// We need to create the variable before constructing the queue so that the Handler function below can capture it.
// It'll be fully assigned before it's used since the delegate is invoked by the Unsubscribe instance method,
// which isn't available until the value is fully initialized.
ChannelMessageQueue messageQueue = null;
messageQueue = new ChannelMessageQueue(
channel,
queue.Reader,
(f) => Unsubscribe(channel, Handler, f),
(f) => UnsubscribeAsync(channel, Handler, f),
(ex) => multiplexer?.OnInternalError(ex));

void Handler(RedisChannel c, RedisValue v)
{
if (c.IsNull && v.IsNull)
{
queue.Writer.TryComplete();
}
else
{
var wrote = queue.Writer.TryWrite(new ChannelMessage(messageQueue, c, v));
Debug.Assert(wrote, "Queue should be unbounded!");
}
}

return (messageQueue, Handler);
}
}
}
Loading