diff --git a/eng/Versions.props b/eng/Versions.props
index 331bb3cfd761..ed9961676d06 100644
--- a/eng/Versions.props
+++ b/eng/Versions.props
@@ -199,7 +199,7 @@
0.3.0-alpha.19317.1
4.3.0
4.3.2
- 4.5.2
+ 4.5.3
1.10.0
5.2.6
@@ -242,7 +242,7 @@
3.0.0
3.0.0
3.0.0
- 1.7.3.7
+ 2.0.335
4.10.0
0.10.1
1.0.2
diff --git a/src/SignalR/common/Protocols.MessagePack/src/Protocol/MessagePackHubProtocol.cs b/src/SignalR/common/Protocols.MessagePack/src/Protocol/MessagePackHubProtocol.cs
index 78631d3c7f40..b42dcf7cfd34 100644
--- a/src/SignalR/common/Protocols.MessagePack/src/Protocol/MessagePackHubProtocol.cs
+++ b/src/SignalR/common/Protocols.MessagePack/src/Protocol/MessagePackHubProtocol.cs
@@ -6,10 +6,11 @@
using System.Collections.Generic;
using System.Diagnostics;
using System.IO;
+using System.Linq;
using System.Runtime.ExceptionServices;
-using System.Runtime.InteropServices;
using MessagePack;
using MessagePack.Formatters;
+using MessagePack.Resolvers;
using Microsoft.AspNetCore.Connections;
using Microsoft.AspNetCore.Internal;
using Microsoft.Extensions.Options;
@@ -25,8 +26,7 @@ public class MessagePackHubProtocol : IHubProtocol
private const int VoidResult = 2;
private const int NonVoidResult = 3;
- private IFormatterResolver _resolver;
-
+ private MessagePackSerializerOptions _msgPackSerializerOptions;
private static readonly string ProtocolName = "messagepack";
private static readonly int ProtocolVersion = 1;
@@ -62,7 +62,9 @@ private void SetupResolver(MessagePackHubProtocolOptions options)
// with the provided resolvers
if (options.FormatterResolvers.Count != SignalRResolver.Resolvers.Count)
{
- _resolver = new CombinedResolvers(options.FormatterResolvers);
+ var resolver = CompositeResolver.Create(Array.Empty(), (IReadOnlyList)options.FormatterResolvers);
+ _msgPackSerializerOptions = MessagePackSerializerOptions.Standard.WithResolver(resolver);
+
return;
}
@@ -71,13 +73,14 @@ private void SetupResolver(MessagePackHubProtocolOptions options)
// check if the user customized the resolvers
if (options.FormatterResolvers[i] != SignalRResolver.Resolvers[i])
{
- _resolver = new CombinedResolvers(options.FormatterResolvers);
+ var resolver = CompositeResolver.Create(Array.Empty(), (IReadOnlyList)options.FormatterResolvers);
+ _msgPackSerializerOptions = MessagePackSerializerOptions.Standard.WithResolver(resolver);
return;
}
}
// Use optimized cached resolver if the default is chosen
- _resolver = SignalRResolver.Instance;
+ _msgPackSerializerOptions = MessagePackSerializerOptions.Standard.WithResolver(SignalRResolver.Instance);
}
///
@@ -95,59 +98,43 @@ public bool TryParseMessage(ref ReadOnlySequence input, IInvocationBinder
return false;
}
- var arraySegment = GetArraySegment(payload);
-
- message = ParseMessage(arraySegment.Array, arraySegment.Offset, binder, _resolver);
+ var reader = new MessagePackReader(payload);
+ message = ParseMessage(ref reader, binder, _msgPackSerializerOptions);
return true;
}
- private static ArraySegment GetArraySegment(in ReadOnlySequence input)
- {
- if (input.IsSingleSegment)
- {
- var isArray = MemoryMarshal.TryGetArray(input.First, out var arraySegment);
- // This will never be false unless we started using un-managed buffers
- Debug.Assert(isArray);
- return arraySegment;
- }
-
- // Should be rare
- return new ArraySegment(input.ToArray());
- }
-
- private static HubMessage ParseMessage(byte[] input, int startOffset, IInvocationBinder binder, IFormatterResolver resolver)
+ private static HubMessage ParseMessage(ref MessagePackReader reader, IInvocationBinder binder, MessagePackSerializerOptions msgPackSerializerOptions)
{
- var itemCount = MessagePackBinary.ReadArrayHeader(input, startOffset, out var readSize);
- startOffset += readSize;
+ var itemCount = reader.ReadArrayHeader();
- var messageType = ReadInt32(input, ref startOffset, "messageType");
+ var messageType = ReadInt32(ref reader, "messageType");
switch (messageType)
{
case HubProtocolConstants.InvocationMessageType:
- return CreateInvocationMessage(input, ref startOffset, binder, resolver, itemCount);
+ return CreateInvocationMessage(ref reader, binder, msgPackSerializerOptions, itemCount);
case HubProtocolConstants.StreamInvocationMessageType:
- return CreateStreamInvocationMessage(input, ref startOffset, binder, resolver, itemCount);
+ return CreateStreamInvocationMessage(ref reader, binder, msgPackSerializerOptions, itemCount);
case HubProtocolConstants.StreamItemMessageType:
- return CreateStreamItemMessage(input, ref startOffset, binder, resolver);
+ return CreateStreamItemMessage(ref reader, binder, msgPackSerializerOptions);
case HubProtocolConstants.CompletionMessageType:
- return CreateCompletionMessage(input, ref startOffset, binder, resolver);
+ return CreateCompletionMessage(ref reader, binder, msgPackSerializerOptions);
case HubProtocolConstants.CancelInvocationMessageType:
- return CreateCancelInvocationMessage(input, ref startOffset);
+ return CreateCancelInvocationMessage(ref reader);
case HubProtocolConstants.PingMessageType:
return PingMessage.Instance;
case HubProtocolConstants.CloseMessageType:
- return CreateCloseMessage(input, ref startOffset, itemCount);
+ return CreateCloseMessage(ref reader, itemCount);
default:
// Future protocol changes can add message types, old clients can ignore them
return null;
}
}
- private static HubMessage CreateInvocationMessage(byte[] input, ref int offset, IInvocationBinder binder, IFormatterResolver resolver, int itemCount)
+ private static HubMessage CreateInvocationMessage(ref MessagePackReader reader, IInvocationBinder binder, MessagePackSerializerOptions msgPackSerializerOptions, int itemCount)
{
- var headers = ReadHeaders(input, ref offset);
- var invocationId = ReadInvocationId(input, ref offset);
+ var headers = ReadHeaders(ref reader);
+ var invocationId = ReadInvocationId(ref reader);
// For MsgPack, we represent an empty invocation ID as an empty string,
// so we need to normalize that to "null", which is what indicates a non-blocking invocation.
@@ -156,13 +143,13 @@ private static HubMessage CreateInvocationMessage(byte[] input, ref int offset,
invocationId = null;
}
- var target = ReadString(input, ref offset, "target");
+ var target = ReadString(ref reader, "target");
object[] arguments = null;
try
{
var parameterTypes = binder.GetParameterTypes(target);
- arguments = BindArguments(input, ref offset, parameterTypes, resolver);
+ arguments = BindArguments(ref reader, parameterTypes, msgPackSerializerOptions);
}
catch (Exception ex)
{
@@ -173,23 +160,23 @@ private static HubMessage CreateInvocationMessage(byte[] input, ref int offset,
// Previous clients will send 5 items, so we check if they sent a stream array or not
if (itemCount > 5)
{
- streams = ReadStreamIds(input, ref offset);
+ streams = ReadStreamIds(ref reader);
}
return ApplyHeaders(headers, new InvocationMessage(invocationId, target, arguments, streams));
}
- private static HubMessage CreateStreamInvocationMessage(byte[] input, ref int offset, IInvocationBinder binder, IFormatterResolver resolver, int itemCount)
+ private static HubMessage CreateStreamInvocationMessage(ref MessagePackReader reader, IInvocationBinder binder, MessagePackSerializerOptions msgPackSerializerOptions, int itemCount)
{
- var headers = ReadHeaders(input, ref offset);
- var invocationId = ReadInvocationId(input, ref offset);
- var target = ReadString(input, ref offset, "target");
+ var headers = ReadHeaders(ref reader);
+ var invocationId = ReadInvocationId(ref reader);
+ var target = ReadString(ref reader, "target");
object[] arguments = null;
try
{
var parameterTypes = binder.GetParameterTypes(target);
- arguments = BindArguments(input, ref offset, parameterTypes, resolver);
+ arguments = BindArguments(ref reader, parameterTypes, msgPackSerializerOptions);
}
catch (Exception ex)
{
@@ -200,21 +187,21 @@ private static HubMessage CreateStreamInvocationMessage(byte[] input, ref int of
// Previous clients will send 5 items, so we check if they sent a stream array or not
if (itemCount > 5)
{
- streams = ReadStreamIds(input, ref offset);
+ streams = ReadStreamIds(ref reader);
}
return ApplyHeaders(headers, new StreamInvocationMessage(invocationId, target, arguments, streams));
}
- private static HubMessage CreateStreamItemMessage(byte[] input, ref int offset, IInvocationBinder binder, IFormatterResolver resolver)
+ private static HubMessage CreateStreamItemMessage(ref MessagePackReader reader, IInvocationBinder binder, MessagePackSerializerOptions msgPackSerializerOptions)
{
- var headers = ReadHeaders(input, ref offset);
- var invocationId = ReadInvocationId(input, ref offset);
+ var headers = ReadHeaders(ref reader);
+ var invocationId = ReadInvocationId(ref reader);
object value;
try
{
var itemType = binder.GetStreamItemType(invocationId);
- value = DeserializeObject(input, ref offset, itemType, "item", resolver);
+ value = DeserializeObject(ref reader, itemType, "item", msgPackSerializerOptions);
}
catch (Exception ex)
{
@@ -224,11 +211,11 @@ private static HubMessage CreateStreamItemMessage(byte[] input, ref int offset,
return ApplyHeaders(headers, new StreamItemMessage(invocationId, value));
}
- private static CompletionMessage CreateCompletionMessage(byte[] input, ref int offset, IInvocationBinder binder, IFormatterResolver resolver)
+ private static CompletionMessage CreateCompletionMessage(ref MessagePackReader reader, IInvocationBinder binder, MessagePackSerializerOptions msgPackSerializerOptions)
{
- var headers = ReadHeaders(input, ref offset);
- var invocationId = ReadInvocationId(input, ref offset);
- var resultKind = ReadInt32(input, ref offset, "resultKind");
+ var headers = ReadHeaders(ref reader);
+ var invocationId = ReadInvocationId(ref reader);
+ var resultKind = ReadInt32(ref reader, "resultKind");
string error = null;
object result = null;
@@ -237,11 +224,11 @@ private static CompletionMessage CreateCompletionMessage(byte[] input, ref int o
switch (resultKind)
{
case ErrorResult:
- error = ReadString(input, ref offset, "error");
+ error = ReadString(ref reader, "error");
break;
case NonVoidResult:
var itemType = binder.GetReturnType(invocationId);
- result = DeserializeObject(input, ref offset, itemType, "argument", resolver);
+ result = DeserializeObject(ref reader, itemType, "argument", msgPackSerializerOptions);
hasResult = true;
break;
case VoidResult:
@@ -254,21 +241,21 @@ private static CompletionMessage CreateCompletionMessage(byte[] input, ref int o
return ApplyHeaders(headers, new CompletionMessage(invocationId, error, result, hasResult));
}
- private static CancelInvocationMessage CreateCancelInvocationMessage(byte[] input, ref int offset)
+ private static CancelInvocationMessage CreateCancelInvocationMessage(ref MessagePackReader reader)
{
- var headers = ReadHeaders(input, ref offset);
- var invocationId = ReadInvocationId(input, ref offset);
+ var headers = ReadHeaders(ref reader);
+ var invocationId = ReadInvocationId(ref reader);
return ApplyHeaders(headers, new CancelInvocationMessage(invocationId));
}
- private static CloseMessage CreateCloseMessage(byte[] input, ref int offset, int itemCount)
+ private static CloseMessage CreateCloseMessage(ref MessagePackReader reader, int itemCount)
{
- var error = ReadString(input, ref offset, "error");
+ var error = ReadString(ref reader, "error");
var allowReconnect = false;
if (itemCount > 2)
{
- allowReconnect = ReadBoolean(input, ref offset, "allowReconnect");
+ allowReconnect = ReadBoolean(ref reader, "allowReconnect");
}
// An empty string is still an error
@@ -280,17 +267,17 @@ private static CloseMessage CreateCloseMessage(byte[] input, ref int offset, int
return new CloseMessage(error, allowReconnect);
}
- private static Dictionary ReadHeaders(byte[] input, ref int offset)
+ private static Dictionary ReadHeaders(ref MessagePackReader reader)
{
- var headerCount = ReadMapLength(input, ref offset, "headers");
+ var headerCount = ReadMapLength(ref reader, "headers");
if (headerCount > 0)
{
var headers = new Dictionary(StringComparer.Ordinal);
for (var i = 0; i < headerCount; i++)
{
- var key = ReadString(input, ref offset, $"headers[{i}].Key");
- var value = ReadString(input, ref offset, $"headers[{i}].Value");
+ var key = ReadString(ref reader, $"headers[{i}].Key");
+ var value = ReadString(ref reader, $"headers[{i}].Value");
headers.Add(key, value);
}
return headers;
@@ -301,9 +288,9 @@ private static Dictionary ReadHeaders(byte[] input, ref int offs
}
}
- private static string[] ReadStreamIds(byte[] input, ref int offset)
+ private static string[] ReadStreamIds(ref MessagePackReader reader)
{
- var streamIdCount = ReadArrayLength(input, ref offset, "streamIds");
+ var streamIdCount = ReadArrayLength(ref reader, "streamIds");
List streams = null;
if (streamIdCount > 0)
@@ -311,17 +298,16 @@ private static string[] ReadStreamIds(byte[] input, ref int offset)
streams = new List();
for (var i = 0; i < streamIdCount; i++)
{
- streams.Add(MessagePackBinary.ReadString(input, offset, out var read));
- offset += read;
+ streams.Add(reader.ReadString());
}
}
return streams?.ToArray();
}
- private static object[] BindArguments(byte[] input, ref int offset, IReadOnlyList parameterTypes, IFormatterResolver resolver)
+ private static object[] BindArguments(ref MessagePackReader reader, IReadOnlyList parameterTypes, MessagePackSerializerOptions msgPackSerializerOptions)
{
- var argumentCount = ReadArrayLength(input, ref offset, "arguments");
+ var argumentCount = ReadArrayLength(ref reader, "arguments");
if (parameterTypes.Count != argumentCount)
{
@@ -334,7 +320,7 @@ private static object[] BindArguments(byte[] input, ref int offset, IReadOnlyLis
var arguments = new object[argumentCount];
for (var i = 0; i < argumentCount; i++)
{
- arguments[i] = DeserializeObject(input, ref offset, parameterTypes[i], "argument", resolver);
+ arguments[i] = DeserializeObject(ref reader, parameterTypes[i], "argument", msgPackSerializerOptions);
}
return arguments;
@@ -358,339 +344,314 @@ private static T ApplyHeaders(IDictionary source, T destinati
///
public void WriteMessage(HubMessage message, IBufferWriter output)
{
- var writer = MemoryBufferWriter.Get();
+ var memoryBufferWriter = MemoryBufferWriter.Get();
try
{
+ var writer = new MessagePackWriter(memoryBufferWriter);
+
// Write message to a buffer so we can get its length
- WriteMessageCore(message, writer);
+ WriteMessageCore(message, ref writer);
// Write length then message to output
- BinaryMessageFormatter.WriteLengthPrefix(writer.Length, output);
- writer.CopyTo(output);
+ BinaryMessageFormatter.WriteLengthPrefix(memoryBufferWriter.Length, output);
+ memoryBufferWriter.CopyTo(output);
}
finally
{
- MemoryBufferWriter.Return(writer);
+ MemoryBufferWriter.Return(memoryBufferWriter);
}
}
///
public ReadOnlyMemory GetMessageBytes(HubMessage message)
{
- var writer = MemoryBufferWriter.Get();
+ var memoryBufferWriter = MemoryBufferWriter.Get();
try
{
+ var writer = new MessagePackWriter(memoryBufferWriter);
+
// Write message to a buffer so we can get its length
- WriteMessageCore(message, writer);
+ WriteMessageCore(message, ref writer);
- var dataLength = writer.Length;
- var prefixLength = BinaryMessageFormatter.LengthPrefixLength(writer.Length);
+ var dataLength = memoryBufferWriter.Length;
+ var prefixLength = BinaryMessageFormatter.LengthPrefixLength(memoryBufferWriter.Length);
var array = new byte[dataLength + prefixLength];
var span = array.AsSpan();
// Write length then message to output
- var written = BinaryMessageFormatter.WriteLengthPrefix(writer.Length, span);
+ var written = BinaryMessageFormatter.WriteLengthPrefix(memoryBufferWriter.Length, span);
Debug.Assert(written == prefixLength);
- writer.CopyTo(span.Slice(prefixLength));
+ memoryBufferWriter.CopyTo(span.Slice(prefixLength));
return array;
}
finally
{
- MemoryBufferWriter.Return(writer);
+ MemoryBufferWriter.Return(memoryBufferWriter);
}
}
- private void WriteMessageCore(HubMessage message, Stream packer)
+ private void WriteMessageCore(HubMessage message, ref MessagePackWriter writer)
{
switch (message)
{
case InvocationMessage invocationMessage:
- WriteInvocationMessage(invocationMessage, packer);
+ WriteInvocationMessage(invocationMessage, ref writer);
break;
case StreamInvocationMessage streamInvocationMessage:
- WriteStreamInvocationMessage(streamInvocationMessage, packer);
+ WriteStreamInvocationMessage(streamInvocationMessage, ref writer);
break;
case StreamItemMessage streamItemMessage:
- WriteStreamingItemMessage(streamItemMessage, packer);
+ WriteStreamingItemMessage(streamItemMessage, ref writer);
break;
case CompletionMessage completionMessage:
- WriteCompletionMessage(completionMessage, packer);
+ WriteCompletionMessage(completionMessage, ref writer);
break;
case CancelInvocationMessage cancelInvocationMessage:
- WriteCancelInvocationMessage(cancelInvocationMessage, packer);
+ WriteCancelInvocationMessage(cancelInvocationMessage, ref writer);
break;
case PingMessage pingMessage:
- WritePingMessage(pingMessage, packer);
+ WritePingMessage(pingMessage, ref writer);
break;
case CloseMessage closeMessage:
- WriteCloseMessage(closeMessage, packer);
+ WriteCloseMessage(closeMessage, ref writer);
break;
default:
throw new InvalidDataException($"Unexpected message type: {message.GetType().Name}");
}
+
+ writer.Flush();
}
- private void WriteInvocationMessage(InvocationMessage message, Stream packer)
+ private void WriteInvocationMessage(InvocationMessage message, ref MessagePackWriter writer)
{
- MessagePackBinary.WriteArrayHeader(packer, 6);
+ writer.WriteArrayHeader(6);
- MessagePackBinary.WriteInt32(packer, HubProtocolConstants.InvocationMessageType);
- PackHeaders(packer, message.Headers);
+ writer.Write(HubProtocolConstants.InvocationMessageType);
+ PackHeaders(message.Headers, ref writer);
if (string.IsNullOrEmpty(message.InvocationId))
{
- MessagePackBinary.WriteNil(packer);
+ writer.WriteNil();
}
else
{
- MessagePackBinary.WriteString(packer, message.InvocationId);
+ writer.Write(message.InvocationId);
}
- MessagePackBinary.WriteString(packer, message.Target);
- MessagePackBinary.WriteArrayHeader(packer, message.Arguments.Length);
+ writer.Write(message.Target);
+ writer.WriteArrayHeader(message.Arguments.Length);
foreach (var arg in message.Arguments)
{
- WriteArgument(arg, packer);
+ WriteArgument(arg, ref writer);
}
- WriteStreamIds(message.StreamIds, packer);
+ WriteStreamIds(message.StreamIds, ref writer);
}
- private void WriteStreamInvocationMessage(StreamInvocationMessage message, Stream packer)
+ private void WriteStreamInvocationMessage(StreamInvocationMessage message, ref MessagePackWriter writer)
{
- MessagePackBinary.WriteArrayHeader(packer, 6);
+ writer.WriteArrayHeader(6);
- MessagePackBinary.WriteInt16(packer, HubProtocolConstants.StreamInvocationMessageType);
- PackHeaders(packer, message.Headers);
- MessagePackBinary.WriteString(packer, message.InvocationId);
- MessagePackBinary.WriteString(packer, message.Target);
+ writer.Write(HubProtocolConstants.StreamInvocationMessageType);
+ PackHeaders(message.Headers, ref writer);
+ writer.Write(message.InvocationId);
+ writer.Write(message.Target);
- MessagePackBinary.WriteArrayHeader(packer, message.Arguments.Length);
+ writer.WriteArrayHeader(message.Arguments.Length);
foreach (var arg in message.Arguments)
{
- WriteArgument(arg, packer);
+ WriteArgument(arg, ref writer);
}
- WriteStreamIds(message.StreamIds, packer);
+ WriteStreamIds(message.StreamIds, ref writer);
}
- private void WriteStreamingItemMessage(StreamItemMessage message, Stream packer)
+ private void WriteStreamingItemMessage(StreamItemMessage message, ref MessagePackWriter writer)
{
- MessagePackBinary.WriteArrayHeader(packer, 4);
- MessagePackBinary.WriteInt16(packer, HubProtocolConstants.StreamItemMessageType);
- PackHeaders(packer, message.Headers);
- MessagePackBinary.WriteString(packer, message.InvocationId);
- WriteArgument(message.Item, packer);
+ writer.WriteArrayHeader(4);
+ writer.Write(HubProtocolConstants.StreamItemMessageType);
+ PackHeaders(message.Headers, ref writer);
+ writer.Write(message.InvocationId);
+ WriteArgument(message.Item, ref writer);
}
- private void WriteArgument(object argument, Stream stream)
+ private void WriteArgument(object argument, ref MessagePackWriter writer)
{
if (argument == null)
{
- MessagePackBinary.WriteNil(stream);
+ writer.WriteNil();
}
else
{
- MessagePackSerializer.NonGeneric.Serialize(argument.GetType(), stream, argument, _resolver);
+ MessagePackSerializer.Serialize(argument.GetType(), ref writer, argument, _msgPackSerializerOptions);
}
}
- private void WriteStreamIds(string[] streamIds, Stream packer)
+ private void WriteStreamIds(string[] streamIds, ref MessagePackWriter writer)
{
if (streamIds != null)
{
- MessagePackBinary.WriteArrayHeader(packer, streamIds.Length);
+ writer.WriteArrayHeader(streamIds.Length);
foreach (var streamId in streamIds)
{
- MessagePackBinary.WriteString(packer, streamId);
+ writer.Write(streamId);
}
}
else
{
- MessagePackBinary.WriteArrayHeader(packer, 0);
+ writer.WriteArrayHeader(0);
}
}
- private void WriteCompletionMessage(CompletionMessage message, Stream packer)
+ private void WriteCompletionMessage(CompletionMessage message, ref MessagePackWriter writer)
{
var resultKind =
message.Error != null ? ErrorResult :
message.HasResult ? NonVoidResult :
VoidResult;
- MessagePackBinary.WriteArrayHeader(packer, 4 + (resultKind != VoidResult ? 1 : 0));
- MessagePackBinary.WriteInt32(packer, HubProtocolConstants.CompletionMessageType);
- PackHeaders(packer, message.Headers);
- MessagePackBinary.WriteString(packer, message.InvocationId);
- MessagePackBinary.WriteInt32(packer, resultKind);
+ writer.WriteArrayHeader(4 + (resultKind != VoidResult ? 1 : 0));
+ writer.Write(HubProtocolConstants.CompletionMessageType);
+ PackHeaders(message.Headers, ref writer);
+ writer.Write(message.InvocationId);
+ writer.Write(resultKind);
switch (resultKind)
{
case ErrorResult:
- MessagePackBinary.WriteString(packer, message.Error);
+ writer.Write(message.Error);
break;
case NonVoidResult:
- WriteArgument(message.Result, packer);
+ WriteArgument(message.Result, ref writer);
break;
}
}
- private void WriteCancelInvocationMessage(CancelInvocationMessage message, Stream packer)
+ private void WriteCancelInvocationMessage(CancelInvocationMessage message, ref MessagePackWriter writer)
{
- MessagePackBinary.WriteArrayHeader(packer, 3);
- MessagePackBinary.WriteInt16(packer, HubProtocolConstants.CancelInvocationMessageType);
- PackHeaders(packer, message.Headers);
- MessagePackBinary.WriteString(packer, message.InvocationId);
+ writer.WriteArrayHeader(3);
+ writer.Write(HubProtocolConstants.CancelInvocationMessageType);
+ PackHeaders(message.Headers, ref writer);
+ writer.Write(message.InvocationId);
}
- private void WriteCloseMessage(CloseMessage message, Stream packer)
+ private void WriteCloseMessage(CloseMessage message, ref MessagePackWriter writer)
{
- MessagePackBinary.WriteArrayHeader(packer, 3);
- MessagePackBinary.WriteInt16(packer, HubProtocolConstants.CloseMessageType);
+ writer.WriteArrayHeader(3);
+ writer.Write(HubProtocolConstants.CloseMessageType);
if (string.IsNullOrEmpty(message.Error))
{
- MessagePackBinary.WriteNil(packer);
+ writer.WriteNil();
}
else
{
- MessagePackBinary.WriteString(packer, message.Error);
+ writer.Write(message.Error);
}
- MessagePackBinary.WriteBoolean(packer, message.AllowReconnect);
+ writer.Write(message.AllowReconnect);
}
- private void WritePingMessage(PingMessage pingMessage, Stream packer)
+ private void WritePingMessage(PingMessage pingMessage, ref MessagePackWriter writer)
{
- MessagePackBinary.WriteArrayHeader(packer, 1);
- MessagePackBinary.WriteInt32(packer, HubProtocolConstants.PingMessageType);
+ writer.WriteArrayHeader(1);
+ writer.Write(HubProtocolConstants.PingMessageType);
}
- private void PackHeaders(Stream packer, IDictionary headers)
+ private void PackHeaders(IDictionary headers, ref MessagePackWriter writer)
{
if (headers != null)
{
- MessagePackBinary.WriteMapHeader(packer, headers.Count);
+ writer.WriteMapHeader(headers.Count);
if (headers.Count > 0)
{
foreach (var header in headers)
{
- MessagePackBinary.WriteString(packer, header.Key);
- MessagePackBinary.WriteString(packer, header.Value);
+ writer.Write(header.Key);
+ writer.Write(header.Value);
}
}
}
else
{
- MessagePackBinary.WriteMapHeader(packer, 0);
+ writer.WriteMapHeader(0);
}
}
- private static string ReadInvocationId(byte[] input, ref int offset)
- {
- return ReadString(input, ref offset, "invocationId");
- }
+ private static string ReadInvocationId(ref MessagePackReader reader) =>
+ ReadString(ref reader, "invocationId");
- private static bool ReadBoolean(byte[] input, ref int offset, string field)
+ private static bool ReadBoolean(ref MessagePackReader reader, string field)
{
- Exception msgPackException = null;
try
{
- var readBool = MessagePackBinary.ReadBoolean(input, offset, out var readSize);
- offset += readSize;
- return readBool;
+ return reader.ReadBoolean();
}
- catch (Exception e)
+ catch (Exception ex)
{
- msgPackException = e;
+ throw new InvalidDataException($"Reading '{field}' as Boolean failed.", ex);
}
-
- throw new InvalidDataException($"Reading '{field}' as Boolean failed.", msgPackException);
}
- private static int ReadInt32(byte[] input, ref int offset, string field)
+ private static int ReadInt32(ref MessagePackReader reader, string field)
{
- Exception msgPackException = null;
try
{
- var readInt = MessagePackBinary.ReadInt32(input, offset, out var readSize);
- offset += readSize;
- return readInt;
+ return reader.ReadInt32();
}
- catch (Exception e)
+ catch (Exception ex)
{
- msgPackException = e;
+ throw new InvalidDataException($"Reading '{field}' as Int32 failed.", ex);
}
-
- throw new InvalidDataException($"Reading '{field}' as Int32 failed.", msgPackException);
}
- private static string ReadString(byte[] input, ref int offset, string field)
+ private static string ReadString(ref MessagePackReader reader, string field)
{
- Exception msgPackException = null;
try
{
- var readString = MessagePackBinary.ReadString(input, offset, out var readSize);
- offset += readSize;
- return readString;
+ return reader.ReadString();
}
- catch (Exception e)
+ catch (Exception ex)
{
- msgPackException = e;
+ throw new InvalidDataException($"Reading '{field}' as String failed.", ex);
}
-
- throw new InvalidDataException($"Reading '{field}' as String failed.", msgPackException);
}
- private static long ReadMapLength(byte[] input, ref int offset, string field)
+ private static long ReadMapLength(ref MessagePackReader reader, string field)
{
- Exception msgPackException = null;
try
{
- var readMap = MessagePackBinary.ReadMapHeader(input, offset, out var readSize);
- offset += readSize;
- return readMap;
+ return reader.ReadMapHeader();
}
- catch (Exception e)
+ catch (Exception ex)
{
- msgPackException = e;
+ throw new InvalidDataException($"Reading map length for '{field}' failed.", ex);
}
- throw new InvalidDataException($"Reading map length for '{field}' failed.", msgPackException);
}
- private static long ReadArrayLength(byte[] input, ref int offset, string field)
+ private static long ReadArrayLength(ref MessagePackReader reader, string field)
{
- Exception msgPackException = null;
try
{
- var readArray = MessagePackBinary.ReadArrayHeader(input, offset, out var readSize);
- offset += readSize;
- return readArray;
+ return reader.ReadArrayHeader();
}
- catch (Exception e)
+ catch (Exception ex)
{
- msgPackException = e;
+ throw new InvalidDataException($"Reading array length for '{field}' failed.", ex);
}
-
- throw new InvalidDataException($"Reading array length for '{field}' failed.", msgPackException);
}
- private static object DeserializeObject(byte[] input, ref int offset, Type type, string field, IFormatterResolver resolver)
+ private static object DeserializeObject(ref MessagePackReader reader, Type type, string field, MessagePackSerializerOptions msgPackSerializerOptions)
{
- Exception msgPackException = null;
try
{
- var obj = MessagePackSerializer.NonGeneric.Deserialize(type, new ArraySegment(input, offset, input.Length - offset), resolver);
- offset += MessagePackBinary.ReadNextBlock(input, offset);
- return obj;
+ return MessagePackSerializer.Deserialize(type, ref reader, msgPackSerializerOptions);
}
catch (Exception ex)
{
- msgPackException = ex;
+ throw new InvalidDataException($"Deserializing object of the `{type.Name}` type for '{field}' failed.", ex);
}
-
- throw new InvalidDataException($"Deserializing object of the `{type.Name}` type for '{field}' failed.", msgPackException);
}
internal static List CreateDefaultFormatterResolvers()
@@ -703,10 +664,10 @@ internal class SignalRResolver : IFormatterResolver
{
public static readonly IFormatterResolver Instance = new SignalRResolver();
- public static readonly IList Resolvers = new[]
+ public static readonly IList Resolvers = new IFormatterResolver[]
{
- MessagePack.Resolvers.DynamicEnumAsStringResolver.Instance,
- MessagePack.Resolvers.ContractlessStandardResolver.Instance,
+ DynamicEnumAsStringResolver.Instance,
+ ContractlessStandardResolver.Instance,
};
public IMessagePackFormatter GetFormatter()
@@ -731,30 +692,5 @@ static Cache()
}
}
}
-
- // Support for users making their own Formatter lists
- internal class CombinedResolvers : IFormatterResolver
- {
- private readonly IList _resolvers;
-
- public CombinedResolvers(IList resolvers)
- {
- _resolvers = resolvers;
- }
-
- public IMessagePackFormatter GetFormatter()
- {
- foreach (var resolver in _resolvers)
- {
- var formatter = resolver.GetFormatter();
- if (formatter != null)
- {
- return formatter;
- }
- }
-
- return null;
- }
- }
}
}
diff --git a/src/SignalR/common/SignalR.Common/test/Internal/Protocol/MessagePackHubProtocolTests.cs b/src/SignalR/common/SignalR.Common/test/Internal/Protocol/MessagePackHubProtocolTests.cs
index 8a87a67bd695..2355228bd9c0 100644
--- a/src/SignalR/common/SignalR.Common/test/Internal/Protocol/MessagePackHubProtocolTests.cs
+++ b/src/SignalR/common/SignalR.Common/test/Internal/Protocol/MessagePackHubProtocolTests.cs
@@ -26,15 +26,20 @@ public void SerializerCanSerializeTypesWithNoDefaultCtor()
AssertMessages(new byte[] { ArrayBytes(5), 3, 0x80, StringBytes(1), (byte)'0', 0x03, ArrayBytes(1), 42 }, result);
}
- [Fact]
- public void WriteAndParseDateTimeConvertsToUTC()
+ [Theory]
+ [InlineData(DateTimeKind.Utc)]
+ [InlineData(DateTimeKind.Local)]
+ [InlineData(DateTimeKind.Unspecified)]
+ public void WriteAndParseDateTimeConvertsToUTC(DateTimeKind dateTimeKind)
{
- var dateTime = new DateTime(2018, 4, 9);
+ // The messagepack Timestamp format always converts input DateTime to Utc if they are passed as "DateTimeKind.Local" :
+ // https://github.com/neuecc/MessagePack-CSharp/pull/520/files#diff-ed970b3daebc708ce49f55d418075979
+ var originalDateTime = new DateTime(2018, 4, 9, 0, 0, 0, dateTimeKind);
var writer = MemoryBufferWriter.Get();
try
{
- HubProtocol.WriteMessage(CompletionMessage.WithResult("xyz", dateTime), writer);
+ HubProtocol.WriteMessage(CompletionMessage.WithResult("xyz", originalDateTime), writer);
var bytes = new ReadOnlySequence(writer.ToArray());
HubProtocol.TryParseMessage(ref bytes, new TestBinder(typeof(DateTime)), out var hubMessage);
@@ -44,7 +49,10 @@ public void WriteAndParseDateTimeConvertsToUTC()
// The messagepack Timestamp format specifies that time is stored as seconds since 1970-01-01 UTC
// so the library has no choice but to store the time as UTC
// https://github.com/msgpack/msgpack/blob/master/spec.md#timestamp-extension-type
- Assert.Equal(dateTime.ToUniversalTime(), resultDateTime);
+ // So If the original DateTiem was a "Local" one, we create a new DateTime equivalent to the original one but converted to Utc
+ var expectedUtcDateTime = (originalDateTime.Kind == DateTimeKind.Local) ? originalDateTime.ToUniversalTime() : originalDateTime;
+
+ Assert.Equal(expectedUtcDateTime, resultDateTime);
}
finally
{
diff --git a/src/SignalR/server/SignalR/test/HubConnectionHandlerTests.cs b/src/SignalR/server/SignalR/test/HubConnectionHandlerTests.cs
index 3d9c97c1ac19..4f2e9e3cd34c 100644
--- a/src/SignalR/server/SignalR/test/HubConnectionHandlerTests.cs
+++ b/src/SignalR/server/SignalR/test/HubConnectionHandlerTests.cs
@@ -2512,33 +2512,15 @@ public IMessagePackFormatter GetFormatter()
private class StringFormatter : IMessagePackFormatter
{
- public T Deserialize(byte[] bytes, int offset, IFormatterResolver formatterResolver, out int readSize)
+ public T Deserialize(ref MessagePackReader reader, MessagePackSerializerOptions options)
{
// this method isn't used in our tests
- readSize = 0;
return default;
}
- public int Serialize(ref byte[] bytes, int offset, T value, IFormatterResolver formatterResolver)
- {
- // string of size 15
- bytes[offset] = 0xAF;
- bytes[offset + 1] = (byte)'f';
- bytes[offset + 2] = (byte)'o';
- bytes[offset + 3] = (byte)'r';
- bytes[offset + 4] = (byte)'m';
- bytes[offset + 5] = (byte)'a';
- bytes[offset + 6] = (byte)'t';
- bytes[offset + 7] = (byte)'t';
- bytes[offset + 8] = (byte)'e';
- bytes[offset + 9] = (byte)'d';
- bytes[offset + 10] = (byte)'S';
- bytes[offset + 11] = (byte)'t';
- bytes[offset + 12] = (byte)'r';
- bytes[offset + 13] = (byte)'i';
- bytes[offset + 14] = (byte)'n';
- bytes[offset + 15] = (byte)'g';
- return 16;
+ public void Serialize(ref MessagePackWriter writer, T value, MessagePackSerializerOptions options)
+ {
+ writer.Write("formattedString");
}
}
}
diff --git a/src/SignalR/server/StackExchangeRedis/src/Internal/MessagePackUtil.cs b/src/SignalR/server/StackExchangeRedis/src/Internal/MessagePackUtil.cs
deleted file mode 100644
index 7780bca98850..000000000000
--- a/src/SignalR/server/StackExchangeRedis/src/Internal/MessagePackUtil.cs
+++ /dev/null
@@ -1,68 +0,0 @@
-// Copyright (c) .NET Foundation. All rights reserved.
-// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
-
-using System;
-using System.Diagnostics;
-using System.Runtime.InteropServices;
-using MessagePack;
-
-namespace Microsoft.AspNetCore.SignalR.StackExchangeRedis.Internal
-{
- internal static class MessagePackUtil
- {
- public static int ReadArrayHeader(ref ReadOnlyMemory data)
- {
- var arr = GetArray(data);
- var val = MessagePackBinary.ReadArrayHeader(arr.Array, arr.Offset, out var readSize);
- data = data.Slice(readSize);
- return val;
- }
-
- public static int ReadMapHeader(ref ReadOnlyMemory data)
- {
- var arr = GetArray(data);
- var val = MessagePackBinary.ReadMapHeader(arr.Array, arr.Offset, out var readSize);
- data = data.Slice(readSize);
- return val;
- }
-
- public static string ReadString(ref ReadOnlyMemory data)
- {
- var arr = GetArray(data);
- var val = MessagePackBinary.ReadString(arr.Array, arr.Offset, out var readSize);
- data = data.Slice(readSize);
- return val;
- }
-
- public static byte[] ReadBytes(ref ReadOnlyMemory data)
- {
- var arr = GetArray(data);
- var val = MessagePackBinary.ReadBytes(arr.Array, arr.Offset, out var readSize);
- data = data.Slice(readSize);
- return val;
- }
-
- public static int ReadInt32(ref ReadOnlyMemory data)
- {
- var arr = GetArray(data);
- var val = MessagePackBinary.ReadInt32(arr.Array, arr.Offset, out var readSize);
- data = data.Slice(readSize);
- return val;
- }
-
- public static byte ReadByte(ref ReadOnlyMemory data)
- {
- var arr = GetArray(data);
- var val = MessagePackBinary.ReadByte(arr.Array, arr.Offset, out var readSize);
- data = data.Slice(readSize);
- return val;
- }
-
- private static ArraySegment GetArray(ReadOnlyMemory data)
- {
- var isArray = MemoryMarshal.TryGetArray(data, out var array);
- Debug.Assert(isArray);
- return array;
- }
- }
-}
diff --git a/src/SignalR/server/StackExchangeRedis/src/Internal/RedisProtocol.cs b/src/SignalR/server/StackExchangeRedis/src/Internal/RedisProtocol.cs
index 24426184895c..3126b52f5d49 100644
--- a/src/SignalR/server/StackExchangeRedis/src/Internal/RedisProtocol.cs
+++ b/src/SignalR/server/StackExchangeRedis/src/Internal/RedisProtocol.cs
@@ -2,6 +2,7 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
+using System.Buffers;
using System.Collections.Generic;
using System.Diagnostics;
using System.IO;
@@ -43,30 +44,33 @@ public byte[] WriteInvocation(string methodName, object[] args, IReadOnlyList 0)
{
- MessagePackBinary.WriteArrayHeader(writer, excludedConnectionIds.Count);
+ writer.WriteArrayHeader(excludedConnectionIds.Count);
foreach (var id in excludedConnectionIds)
{
- MessagePackBinary.WriteString(writer, id);
+ writer.Write(id);
}
}
else
{
- MessagePackBinary.WriteArrayHeader(writer, 0);
+ writer.WriteArrayHeader(0);
}
- WriteHubMessage(writer, new InvocationMessage(methodName, args));
- return writer.ToArray();
+ WriteHubMessage(ref writer, new InvocationMessage(methodName, args));
+ writer.Flush();
+
+ return memoryBufferWriter.ToArray();
}
finally
{
- MemoryBufferWriter.Return(writer);
+ MemoryBufferWriter.Return(memoryBufferWriter);
}
}
@@ -80,21 +84,24 @@ public byte[] WriteGroupCommand(RedisGroupCommand command)
// * A 'str': The connection Id
// Any additional items are discarded.
- var writer = MemoryBufferWriter.Get();
+ var memoryBufferWriter = MemoryBufferWriter.Get();
try
{
- MessagePackBinary.WriteArrayHeader(writer, 5);
- MessagePackBinary.WriteInt32(writer, command.Id);
- MessagePackBinary.WriteString(writer, command.ServerName);
- MessagePackBinary.WriteByte(writer, (byte)command.Action);
- MessagePackBinary.WriteString(writer, command.GroupName);
- MessagePackBinary.WriteString(writer, command.ConnectionId);
-
- return writer.ToArray();
+ var writer = new MessagePackWriter(memoryBufferWriter);
+
+ writer.WriteArrayHeader(5);
+ writer.Write(command.Id);
+ writer.Write(command.ServerName);
+ writer.Write((byte)command.Action);
+ writer.Write(command.GroupName);
+ writer.Write(command.ConnectionId);
+ writer.Flush();
+
+ return memoryBufferWriter.ToArray();
}
finally
{
- MemoryBufferWriter.Return(writer);
+ MemoryBufferWriter.Return(memoryBufferWriter);
}
}
@@ -104,101 +111,110 @@ public byte[] WriteAck(int messageId)
// * An 'int': The Id of the command being acknowledged.
// Any additional items are discarded.
- var writer = MemoryBufferWriter.Get();
+ var memoryBufferWriter = MemoryBufferWriter.Get();
try
{
- MessagePackBinary.WriteArrayHeader(writer, 1);
- MessagePackBinary.WriteInt32(writer, messageId);
+ var writer = new MessagePackWriter(memoryBufferWriter);
- return writer.ToArray();
+ writer.WriteArrayHeader(1);
+ writer.Write(messageId);
+ writer.Flush();
+
+ return memoryBufferWriter.ToArray();
}
finally
{
- MemoryBufferWriter.Return(writer);
+ MemoryBufferWriter.Return(memoryBufferWriter);
}
}
public RedisInvocation ReadInvocation(ReadOnlyMemory data)
{
// See WriteInvocation for the format
- ValidateArraySize(ref data, 2, "Invocation");
+ var reader = new MessagePackReader(data);
+ ValidateArraySize(ref reader, 2, "Invocation");
// Read excluded Ids
IReadOnlyList excludedConnectionIds = null;
- var idCount = MessagePackUtil.ReadArrayHeader(ref data);
+ var idCount = reader.ReadArrayHeader();
if (idCount > 0)
{
var ids = new string[idCount];
for (var i = 0; i < idCount; i++)
{
- ids[i] = MessagePackUtil.ReadString(ref data);
+ ids[i] = reader.ReadString();
}
excludedConnectionIds = ids;
}
// Read payload
- var message = ReadSerializedHubMessage(ref data);
+ var message = ReadSerializedHubMessage(ref reader);
return new RedisInvocation(message, excludedConnectionIds);
}
public RedisGroupCommand ReadGroupCommand(ReadOnlyMemory data)
{
+ var reader = new MessagePackReader(data);
+
// See WriteGroupCommand for format.
- ValidateArraySize(ref data, 5, "GroupCommand");
+ ValidateArraySize(ref reader, 5, "GroupCommand");
- var id = MessagePackUtil.ReadInt32(ref data);
- var serverName = MessagePackUtil.ReadString(ref data);
- var action = (GroupAction)MessagePackUtil.ReadByte(ref data);
- var groupName = MessagePackUtil.ReadString(ref data);
- var connectionId = MessagePackUtil.ReadString(ref data);
+ var id = reader.ReadInt32();
+ var serverName = reader.ReadString();
+ var action = (GroupAction)reader.ReadByte();
+ var groupName = reader.ReadString();
+ var connectionId = reader.ReadString();
return new RedisGroupCommand(id, serverName, action, groupName, connectionId);
}
public int ReadAck(ReadOnlyMemory data)
{
+ var reader = new MessagePackReader(data);
+
// See WriteAck for format
- ValidateArraySize(ref data, 1, "Ack");
- return MessagePackUtil.ReadInt32(ref data);
+ ValidateArraySize(ref reader, 1, "Ack");
+ return reader.ReadInt32();
}
- private void WriteHubMessage(Stream stream, HubMessage message)
+ private void WriteHubMessage(ref MessagePackWriter writer, HubMessage message)
{
// Written as a MessagePack 'map' where the keys are the name of the protocol (as a MessagePack 'str')
// and the values are the serialized blob (as a MessagePack 'bin').
var serializedHubMessages = _messageSerializer.SerializeMessage(message);
- MessagePackBinary.WriteMapHeader(stream, serializedHubMessages.Count);
+ writer.WriteMapHeader(serializedHubMessages.Count);
foreach (var serializedMessage in serializedHubMessages)
{
- MessagePackBinary.WriteString(stream, serializedMessage.ProtocolName);
+ writer.Write(serializedMessage.ProtocolName);
var isArray = MemoryMarshal.TryGetArray(serializedMessage.Serialized, out var array);
Debug.Assert(isArray);
- MessagePackBinary.WriteBytes(stream, array.Array, array.Offset, array.Count);
+ writer.Write(array);
}
}
- public static SerializedHubMessage ReadSerializedHubMessage(ref ReadOnlyMemory data)
+ public static SerializedHubMessage ReadSerializedHubMessage(ref MessagePackReader reader)
{
- var count = MessagePackUtil.ReadMapHeader(ref data);
+ var count = reader.ReadMapHeader();
var serializations = new SerializedMessage[count];
for (var i = 0; i < count; i++)
{
- var protocol = MessagePackUtil.ReadString(ref data);
- var serialized = MessagePackUtil.ReadBytes(ref data);
+ var protocol = reader.ReadString();
+ var serialized = reader.ReadBytes()?.ToArray() ?? Array.Empty();
+
serializations[i] = new SerializedMessage(protocol, serialized);
}
return new SerializedHubMessage(serializations);
}
- private static void ValidateArraySize(ref ReadOnlyMemory data, int expectedLength, string messageType)
+ private static void ValidateArraySize(ref MessagePackReader reader, int expectedLength, string messageType)
{
- var length = MessagePackUtil.ReadArrayHeader(ref data);
+ var length = reader.ReadArrayHeader();
if (length < expectedLength)
{