diff --git a/eng/Versions.props b/eng/Versions.props index 331bb3cfd761..ed9961676d06 100644 --- a/eng/Versions.props +++ b/eng/Versions.props @@ -199,7 +199,7 @@ 0.3.0-alpha.19317.1 4.3.0 4.3.2 - 4.5.2 + 4.5.3 1.10.0 5.2.6 @@ -242,7 +242,7 @@ 3.0.0 3.0.0 3.0.0 - 1.7.3.7 + 2.0.335 4.10.0 0.10.1 1.0.2 diff --git a/src/SignalR/common/Protocols.MessagePack/src/Protocol/MessagePackHubProtocol.cs b/src/SignalR/common/Protocols.MessagePack/src/Protocol/MessagePackHubProtocol.cs index 78631d3c7f40..b42dcf7cfd34 100644 --- a/src/SignalR/common/Protocols.MessagePack/src/Protocol/MessagePackHubProtocol.cs +++ b/src/SignalR/common/Protocols.MessagePack/src/Protocol/MessagePackHubProtocol.cs @@ -6,10 +6,11 @@ using System.Collections.Generic; using System.Diagnostics; using System.IO; +using System.Linq; using System.Runtime.ExceptionServices; -using System.Runtime.InteropServices; using MessagePack; using MessagePack.Formatters; +using MessagePack.Resolvers; using Microsoft.AspNetCore.Connections; using Microsoft.AspNetCore.Internal; using Microsoft.Extensions.Options; @@ -25,8 +26,7 @@ public class MessagePackHubProtocol : IHubProtocol private const int VoidResult = 2; private const int NonVoidResult = 3; - private IFormatterResolver _resolver; - + private MessagePackSerializerOptions _msgPackSerializerOptions; private static readonly string ProtocolName = "messagepack"; private static readonly int ProtocolVersion = 1; @@ -62,7 +62,9 @@ private void SetupResolver(MessagePackHubProtocolOptions options) // with the provided resolvers if (options.FormatterResolvers.Count != SignalRResolver.Resolvers.Count) { - _resolver = new CombinedResolvers(options.FormatterResolvers); + var resolver = CompositeResolver.Create(Array.Empty(), (IReadOnlyList)options.FormatterResolvers); + _msgPackSerializerOptions = MessagePackSerializerOptions.Standard.WithResolver(resolver); + return; } @@ -71,13 +73,14 @@ private void SetupResolver(MessagePackHubProtocolOptions options) // check if the user customized the resolvers if (options.FormatterResolvers[i] != SignalRResolver.Resolvers[i]) { - _resolver = new CombinedResolvers(options.FormatterResolvers); + var resolver = CompositeResolver.Create(Array.Empty(), (IReadOnlyList)options.FormatterResolvers); + _msgPackSerializerOptions = MessagePackSerializerOptions.Standard.WithResolver(resolver); return; } } // Use optimized cached resolver if the default is chosen - _resolver = SignalRResolver.Instance; + _msgPackSerializerOptions = MessagePackSerializerOptions.Standard.WithResolver(SignalRResolver.Instance); } /// @@ -95,59 +98,43 @@ public bool TryParseMessage(ref ReadOnlySequence input, IInvocationBinder return false; } - var arraySegment = GetArraySegment(payload); - - message = ParseMessage(arraySegment.Array, arraySegment.Offset, binder, _resolver); + var reader = new MessagePackReader(payload); + message = ParseMessage(ref reader, binder, _msgPackSerializerOptions); return true; } - private static ArraySegment GetArraySegment(in ReadOnlySequence input) - { - if (input.IsSingleSegment) - { - var isArray = MemoryMarshal.TryGetArray(input.First, out var arraySegment); - // This will never be false unless we started using un-managed buffers - Debug.Assert(isArray); - return arraySegment; - } - - // Should be rare - return new ArraySegment(input.ToArray()); - } - - private static HubMessage ParseMessage(byte[] input, int startOffset, IInvocationBinder binder, IFormatterResolver resolver) + private static HubMessage ParseMessage(ref MessagePackReader reader, IInvocationBinder binder, MessagePackSerializerOptions msgPackSerializerOptions) { - var itemCount = MessagePackBinary.ReadArrayHeader(input, startOffset, out var readSize); - startOffset += readSize; + var itemCount = reader.ReadArrayHeader(); - var messageType = ReadInt32(input, ref startOffset, "messageType"); + var messageType = ReadInt32(ref reader, "messageType"); switch (messageType) { case HubProtocolConstants.InvocationMessageType: - return CreateInvocationMessage(input, ref startOffset, binder, resolver, itemCount); + return CreateInvocationMessage(ref reader, binder, msgPackSerializerOptions, itemCount); case HubProtocolConstants.StreamInvocationMessageType: - return CreateStreamInvocationMessage(input, ref startOffset, binder, resolver, itemCount); + return CreateStreamInvocationMessage(ref reader, binder, msgPackSerializerOptions, itemCount); case HubProtocolConstants.StreamItemMessageType: - return CreateStreamItemMessage(input, ref startOffset, binder, resolver); + return CreateStreamItemMessage(ref reader, binder, msgPackSerializerOptions); case HubProtocolConstants.CompletionMessageType: - return CreateCompletionMessage(input, ref startOffset, binder, resolver); + return CreateCompletionMessage(ref reader, binder, msgPackSerializerOptions); case HubProtocolConstants.CancelInvocationMessageType: - return CreateCancelInvocationMessage(input, ref startOffset); + return CreateCancelInvocationMessage(ref reader); case HubProtocolConstants.PingMessageType: return PingMessage.Instance; case HubProtocolConstants.CloseMessageType: - return CreateCloseMessage(input, ref startOffset, itemCount); + return CreateCloseMessage(ref reader, itemCount); default: // Future protocol changes can add message types, old clients can ignore them return null; } } - private static HubMessage CreateInvocationMessage(byte[] input, ref int offset, IInvocationBinder binder, IFormatterResolver resolver, int itemCount) + private static HubMessage CreateInvocationMessage(ref MessagePackReader reader, IInvocationBinder binder, MessagePackSerializerOptions msgPackSerializerOptions, int itemCount) { - var headers = ReadHeaders(input, ref offset); - var invocationId = ReadInvocationId(input, ref offset); + var headers = ReadHeaders(ref reader); + var invocationId = ReadInvocationId(ref reader); // For MsgPack, we represent an empty invocation ID as an empty string, // so we need to normalize that to "null", which is what indicates a non-blocking invocation. @@ -156,13 +143,13 @@ private static HubMessage CreateInvocationMessage(byte[] input, ref int offset, invocationId = null; } - var target = ReadString(input, ref offset, "target"); + var target = ReadString(ref reader, "target"); object[] arguments = null; try { var parameterTypes = binder.GetParameterTypes(target); - arguments = BindArguments(input, ref offset, parameterTypes, resolver); + arguments = BindArguments(ref reader, parameterTypes, msgPackSerializerOptions); } catch (Exception ex) { @@ -173,23 +160,23 @@ private static HubMessage CreateInvocationMessage(byte[] input, ref int offset, // Previous clients will send 5 items, so we check if they sent a stream array or not if (itemCount > 5) { - streams = ReadStreamIds(input, ref offset); + streams = ReadStreamIds(ref reader); } return ApplyHeaders(headers, new InvocationMessage(invocationId, target, arguments, streams)); } - private static HubMessage CreateStreamInvocationMessage(byte[] input, ref int offset, IInvocationBinder binder, IFormatterResolver resolver, int itemCount) + private static HubMessage CreateStreamInvocationMessage(ref MessagePackReader reader, IInvocationBinder binder, MessagePackSerializerOptions msgPackSerializerOptions, int itemCount) { - var headers = ReadHeaders(input, ref offset); - var invocationId = ReadInvocationId(input, ref offset); - var target = ReadString(input, ref offset, "target"); + var headers = ReadHeaders(ref reader); + var invocationId = ReadInvocationId(ref reader); + var target = ReadString(ref reader, "target"); object[] arguments = null; try { var parameterTypes = binder.GetParameterTypes(target); - arguments = BindArguments(input, ref offset, parameterTypes, resolver); + arguments = BindArguments(ref reader, parameterTypes, msgPackSerializerOptions); } catch (Exception ex) { @@ -200,21 +187,21 @@ private static HubMessage CreateStreamInvocationMessage(byte[] input, ref int of // Previous clients will send 5 items, so we check if they sent a stream array or not if (itemCount > 5) { - streams = ReadStreamIds(input, ref offset); + streams = ReadStreamIds(ref reader); } return ApplyHeaders(headers, new StreamInvocationMessage(invocationId, target, arguments, streams)); } - private static HubMessage CreateStreamItemMessage(byte[] input, ref int offset, IInvocationBinder binder, IFormatterResolver resolver) + private static HubMessage CreateStreamItemMessage(ref MessagePackReader reader, IInvocationBinder binder, MessagePackSerializerOptions msgPackSerializerOptions) { - var headers = ReadHeaders(input, ref offset); - var invocationId = ReadInvocationId(input, ref offset); + var headers = ReadHeaders(ref reader); + var invocationId = ReadInvocationId(ref reader); object value; try { var itemType = binder.GetStreamItemType(invocationId); - value = DeserializeObject(input, ref offset, itemType, "item", resolver); + value = DeserializeObject(ref reader, itemType, "item", msgPackSerializerOptions); } catch (Exception ex) { @@ -224,11 +211,11 @@ private static HubMessage CreateStreamItemMessage(byte[] input, ref int offset, return ApplyHeaders(headers, new StreamItemMessage(invocationId, value)); } - private static CompletionMessage CreateCompletionMessage(byte[] input, ref int offset, IInvocationBinder binder, IFormatterResolver resolver) + private static CompletionMessage CreateCompletionMessage(ref MessagePackReader reader, IInvocationBinder binder, MessagePackSerializerOptions msgPackSerializerOptions) { - var headers = ReadHeaders(input, ref offset); - var invocationId = ReadInvocationId(input, ref offset); - var resultKind = ReadInt32(input, ref offset, "resultKind"); + var headers = ReadHeaders(ref reader); + var invocationId = ReadInvocationId(ref reader); + var resultKind = ReadInt32(ref reader, "resultKind"); string error = null; object result = null; @@ -237,11 +224,11 @@ private static CompletionMessage CreateCompletionMessage(byte[] input, ref int o switch (resultKind) { case ErrorResult: - error = ReadString(input, ref offset, "error"); + error = ReadString(ref reader, "error"); break; case NonVoidResult: var itemType = binder.GetReturnType(invocationId); - result = DeserializeObject(input, ref offset, itemType, "argument", resolver); + result = DeserializeObject(ref reader, itemType, "argument", msgPackSerializerOptions); hasResult = true; break; case VoidResult: @@ -254,21 +241,21 @@ private static CompletionMessage CreateCompletionMessage(byte[] input, ref int o return ApplyHeaders(headers, new CompletionMessage(invocationId, error, result, hasResult)); } - private static CancelInvocationMessage CreateCancelInvocationMessage(byte[] input, ref int offset) + private static CancelInvocationMessage CreateCancelInvocationMessage(ref MessagePackReader reader) { - var headers = ReadHeaders(input, ref offset); - var invocationId = ReadInvocationId(input, ref offset); + var headers = ReadHeaders(ref reader); + var invocationId = ReadInvocationId(ref reader); return ApplyHeaders(headers, new CancelInvocationMessage(invocationId)); } - private static CloseMessage CreateCloseMessage(byte[] input, ref int offset, int itemCount) + private static CloseMessage CreateCloseMessage(ref MessagePackReader reader, int itemCount) { - var error = ReadString(input, ref offset, "error"); + var error = ReadString(ref reader, "error"); var allowReconnect = false; if (itemCount > 2) { - allowReconnect = ReadBoolean(input, ref offset, "allowReconnect"); + allowReconnect = ReadBoolean(ref reader, "allowReconnect"); } // An empty string is still an error @@ -280,17 +267,17 @@ private static CloseMessage CreateCloseMessage(byte[] input, ref int offset, int return new CloseMessage(error, allowReconnect); } - private static Dictionary ReadHeaders(byte[] input, ref int offset) + private static Dictionary ReadHeaders(ref MessagePackReader reader) { - var headerCount = ReadMapLength(input, ref offset, "headers"); + var headerCount = ReadMapLength(ref reader, "headers"); if (headerCount > 0) { var headers = new Dictionary(StringComparer.Ordinal); for (var i = 0; i < headerCount; i++) { - var key = ReadString(input, ref offset, $"headers[{i}].Key"); - var value = ReadString(input, ref offset, $"headers[{i}].Value"); + var key = ReadString(ref reader, $"headers[{i}].Key"); + var value = ReadString(ref reader, $"headers[{i}].Value"); headers.Add(key, value); } return headers; @@ -301,9 +288,9 @@ private static Dictionary ReadHeaders(byte[] input, ref int offs } } - private static string[] ReadStreamIds(byte[] input, ref int offset) + private static string[] ReadStreamIds(ref MessagePackReader reader) { - var streamIdCount = ReadArrayLength(input, ref offset, "streamIds"); + var streamIdCount = ReadArrayLength(ref reader, "streamIds"); List streams = null; if (streamIdCount > 0) @@ -311,17 +298,16 @@ private static string[] ReadStreamIds(byte[] input, ref int offset) streams = new List(); for (var i = 0; i < streamIdCount; i++) { - streams.Add(MessagePackBinary.ReadString(input, offset, out var read)); - offset += read; + streams.Add(reader.ReadString()); } } return streams?.ToArray(); } - private static object[] BindArguments(byte[] input, ref int offset, IReadOnlyList parameterTypes, IFormatterResolver resolver) + private static object[] BindArguments(ref MessagePackReader reader, IReadOnlyList parameterTypes, MessagePackSerializerOptions msgPackSerializerOptions) { - var argumentCount = ReadArrayLength(input, ref offset, "arguments"); + var argumentCount = ReadArrayLength(ref reader, "arguments"); if (parameterTypes.Count != argumentCount) { @@ -334,7 +320,7 @@ private static object[] BindArguments(byte[] input, ref int offset, IReadOnlyLis var arguments = new object[argumentCount]; for (var i = 0; i < argumentCount; i++) { - arguments[i] = DeserializeObject(input, ref offset, parameterTypes[i], "argument", resolver); + arguments[i] = DeserializeObject(ref reader, parameterTypes[i], "argument", msgPackSerializerOptions); } return arguments; @@ -358,339 +344,314 @@ private static T ApplyHeaders(IDictionary source, T destinati /// public void WriteMessage(HubMessage message, IBufferWriter output) { - var writer = MemoryBufferWriter.Get(); + var memoryBufferWriter = MemoryBufferWriter.Get(); try { + var writer = new MessagePackWriter(memoryBufferWriter); + // Write message to a buffer so we can get its length - WriteMessageCore(message, writer); + WriteMessageCore(message, ref writer); // Write length then message to output - BinaryMessageFormatter.WriteLengthPrefix(writer.Length, output); - writer.CopyTo(output); + BinaryMessageFormatter.WriteLengthPrefix(memoryBufferWriter.Length, output); + memoryBufferWriter.CopyTo(output); } finally { - MemoryBufferWriter.Return(writer); + MemoryBufferWriter.Return(memoryBufferWriter); } } /// public ReadOnlyMemory GetMessageBytes(HubMessage message) { - var writer = MemoryBufferWriter.Get(); + var memoryBufferWriter = MemoryBufferWriter.Get(); try { + var writer = new MessagePackWriter(memoryBufferWriter); + // Write message to a buffer so we can get its length - WriteMessageCore(message, writer); + WriteMessageCore(message, ref writer); - var dataLength = writer.Length; - var prefixLength = BinaryMessageFormatter.LengthPrefixLength(writer.Length); + var dataLength = memoryBufferWriter.Length; + var prefixLength = BinaryMessageFormatter.LengthPrefixLength(memoryBufferWriter.Length); var array = new byte[dataLength + prefixLength]; var span = array.AsSpan(); // Write length then message to output - var written = BinaryMessageFormatter.WriteLengthPrefix(writer.Length, span); + var written = BinaryMessageFormatter.WriteLengthPrefix(memoryBufferWriter.Length, span); Debug.Assert(written == prefixLength); - writer.CopyTo(span.Slice(prefixLength)); + memoryBufferWriter.CopyTo(span.Slice(prefixLength)); return array; } finally { - MemoryBufferWriter.Return(writer); + MemoryBufferWriter.Return(memoryBufferWriter); } } - private void WriteMessageCore(HubMessage message, Stream packer) + private void WriteMessageCore(HubMessage message, ref MessagePackWriter writer) { switch (message) { case InvocationMessage invocationMessage: - WriteInvocationMessage(invocationMessage, packer); + WriteInvocationMessage(invocationMessage, ref writer); break; case StreamInvocationMessage streamInvocationMessage: - WriteStreamInvocationMessage(streamInvocationMessage, packer); + WriteStreamInvocationMessage(streamInvocationMessage, ref writer); break; case StreamItemMessage streamItemMessage: - WriteStreamingItemMessage(streamItemMessage, packer); + WriteStreamingItemMessage(streamItemMessage, ref writer); break; case CompletionMessage completionMessage: - WriteCompletionMessage(completionMessage, packer); + WriteCompletionMessage(completionMessage, ref writer); break; case CancelInvocationMessage cancelInvocationMessage: - WriteCancelInvocationMessage(cancelInvocationMessage, packer); + WriteCancelInvocationMessage(cancelInvocationMessage, ref writer); break; case PingMessage pingMessage: - WritePingMessage(pingMessage, packer); + WritePingMessage(pingMessage, ref writer); break; case CloseMessage closeMessage: - WriteCloseMessage(closeMessage, packer); + WriteCloseMessage(closeMessage, ref writer); break; default: throw new InvalidDataException($"Unexpected message type: {message.GetType().Name}"); } + + writer.Flush(); } - private void WriteInvocationMessage(InvocationMessage message, Stream packer) + private void WriteInvocationMessage(InvocationMessage message, ref MessagePackWriter writer) { - MessagePackBinary.WriteArrayHeader(packer, 6); + writer.WriteArrayHeader(6); - MessagePackBinary.WriteInt32(packer, HubProtocolConstants.InvocationMessageType); - PackHeaders(packer, message.Headers); + writer.Write(HubProtocolConstants.InvocationMessageType); + PackHeaders(message.Headers, ref writer); if (string.IsNullOrEmpty(message.InvocationId)) { - MessagePackBinary.WriteNil(packer); + writer.WriteNil(); } else { - MessagePackBinary.WriteString(packer, message.InvocationId); + writer.Write(message.InvocationId); } - MessagePackBinary.WriteString(packer, message.Target); - MessagePackBinary.WriteArrayHeader(packer, message.Arguments.Length); + writer.Write(message.Target); + writer.WriteArrayHeader(message.Arguments.Length); foreach (var arg in message.Arguments) { - WriteArgument(arg, packer); + WriteArgument(arg, ref writer); } - WriteStreamIds(message.StreamIds, packer); + WriteStreamIds(message.StreamIds, ref writer); } - private void WriteStreamInvocationMessage(StreamInvocationMessage message, Stream packer) + private void WriteStreamInvocationMessage(StreamInvocationMessage message, ref MessagePackWriter writer) { - MessagePackBinary.WriteArrayHeader(packer, 6); + writer.WriteArrayHeader(6); - MessagePackBinary.WriteInt16(packer, HubProtocolConstants.StreamInvocationMessageType); - PackHeaders(packer, message.Headers); - MessagePackBinary.WriteString(packer, message.InvocationId); - MessagePackBinary.WriteString(packer, message.Target); + writer.Write(HubProtocolConstants.StreamInvocationMessageType); + PackHeaders(message.Headers, ref writer); + writer.Write(message.InvocationId); + writer.Write(message.Target); - MessagePackBinary.WriteArrayHeader(packer, message.Arguments.Length); + writer.WriteArrayHeader(message.Arguments.Length); foreach (var arg in message.Arguments) { - WriteArgument(arg, packer); + WriteArgument(arg, ref writer); } - WriteStreamIds(message.StreamIds, packer); + WriteStreamIds(message.StreamIds, ref writer); } - private void WriteStreamingItemMessage(StreamItemMessage message, Stream packer) + private void WriteStreamingItemMessage(StreamItemMessage message, ref MessagePackWriter writer) { - MessagePackBinary.WriteArrayHeader(packer, 4); - MessagePackBinary.WriteInt16(packer, HubProtocolConstants.StreamItemMessageType); - PackHeaders(packer, message.Headers); - MessagePackBinary.WriteString(packer, message.InvocationId); - WriteArgument(message.Item, packer); + writer.WriteArrayHeader(4); + writer.Write(HubProtocolConstants.StreamItemMessageType); + PackHeaders(message.Headers, ref writer); + writer.Write(message.InvocationId); + WriteArgument(message.Item, ref writer); } - private void WriteArgument(object argument, Stream stream) + private void WriteArgument(object argument, ref MessagePackWriter writer) { if (argument == null) { - MessagePackBinary.WriteNil(stream); + writer.WriteNil(); } else { - MessagePackSerializer.NonGeneric.Serialize(argument.GetType(), stream, argument, _resolver); + MessagePackSerializer.Serialize(argument.GetType(), ref writer, argument, _msgPackSerializerOptions); } } - private void WriteStreamIds(string[] streamIds, Stream packer) + private void WriteStreamIds(string[] streamIds, ref MessagePackWriter writer) { if (streamIds != null) { - MessagePackBinary.WriteArrayHeader(packer, streamIds.Length); + writer.WriteArrayHeader(streamIds.Length); foreach (var streamId in streamIds) { - MessagePackBinary.WriteString(packer, streamId); + writer.Write(streamId); } } else { - MessagePackBinary.WriteArrayHeader(packer, 0); + writer.WriteArrayHeader(0); } } - private void WriteCompletionMessage(CompletionMessage message, Stream packer) + private void WriteCompletionMessage(CompletionMessage message, ref MessagePackWriter writer) { var resultKind = message.Error != null ? ErrorResult : message.HasResult ? NonVoidResult : VoidResult; - MessagePackBinary.WriteArrayHeader(packer, 4 + (resultKind != VoidResult ? 1 : 0)); - MessagePackBinary.WriteInt32(packer, HubProtocolConstants.CompletionMessageType); - PackHeaders(packer, message.Headers); - MessagePackBinary.WriteString(packer, message.InvocationId); - MessagePackBinary.WriteInt32(packer, resultKind); + writer.WriteArrayHeader(4 + (resultKind != VoidResult ? 1 : 0)); + writer.Write(HubProtocolConstants.CompletionMessageType); + PackHeaders(message.Headers, ref writer); + writer.Write(message.InvocationId); + writer.Write(resultKind); switch (resultKind) { case ErrorResult: - MessagePackBinary.WriteString(packer, message.Error); + writer.Write(message.Error); break; case NonVoidResult: - WriteArgument(message.Result, packer); + WriteArgument(message.Result, ref writer); break; } } - private void WriteCancelInvocationMessage(CancelInvocationMessage message, Stream packer) + private void WriteCancelInvocationMessage(CancelInvocationMessage message, ref MessagePackWriter writer) { - MessagePackBinary.WriteArrayHeader(packer, 3); - MessagePackBinary.WriteInt16(packer, HubProtocolConstants.CancelInvocationMessageType); - PackHeaders(packer, message.Headers); - MessagePackBinary.WriteString(packer, message.InvocationId); + writer.WriteArrayHeader(3); + writer.Write(HubProtocolConstants.CancelInvocationMessageType); + PackHeaders(message.Headers, ref writer); + writer.Write(message.InvocationId); } - private void WriteCloseMessage(CloseMessage message, Stream packer) + private void WriteCloseMessage(CloseMessage message, ref MessagePackWriter writer) { - MessagePackBinary.WriteArrayHeader(packer, 3); - MessagePackBinary.WriteInt16(packer, HubProtocolConstants.CloseMessageType); + writer.WriteArrayHeader(3); + writer.Write(HubProtocolConstants.CloseMessageType); if (string.IsNullOrEmpty(message.Error)) { - MessagePackBinary.WriteNil(packer); + writer.WriteNil(); } else { - MessagePackBinary.WriteString(packer, message.Error); + writer.Write(message.Error); } - MessagePackBinary.WriteBoolean(packer, message.AllowReconnect); + writer.Write(message.AllowReconnect); } - private void WritePingMessage(PingMessage pingMessage, Stream packer) + private void WritePingMessage(PingMessage pingMessage, ref MessagePackWriter writer) { - MessagePackBinary.WriteArrayHeader(packer, 1); - MessagePackBinary.WriteInt32(packer, HubProtocolConstants.PingMessageType); + writer.WriteArrayHeader(1); + writer.Write(HubProtocolConstants.PingMessageType); } - private void PackHeaders(Stream packer, IDictionary headers) + private void PackHeaders(IDictionary headers, ref MessagePackWriter writer) { if (headers != null) { - MessagePackBinary.WriteMapHeader(packer, headers.Count); + writer.WriteMapHeader(headers.Count); if (headers.Count > 0) { foreach (var header in headers) { - MessagePackBinary.WriteString(packer, header.Key); - MessagePackBinary.WriteString(packer, header.Value); + writer.Write(header.Key); + writer.Write(header.Value); } } } else { - MessagePackBinary.WriteMapHeader(packer, 0); + writer.WriteMapHeader(0); } } - private static string ReadInvocationId(byte[] input, ref int offset) - { - return ReadString(input, ref offset, "invocationId"); - } + private static string ReadInvocationId(ref MessagePackReader reader) => + ReadString(ref reader, "invocationId"); - private static bool ReadBoolean(byte[] input, ref int offset, string field) + private static bool ReadBoolean(ref MessagePackReader reader, string field) { - Exception msgPackException = null; try { - var readBool = MessagePackBinary.ReadBoolean(input, offset, out var readSize); - offset += readSize; - return readBool; + return reader.ReadBoolean(); } - catch (Exception e) + catch (Exception ex) { - msgPackException = e; + throw new InvalidDataException($"Reading '{field}' as Boolean failed.", ex); } - - throw new InvalidDataException($"Reading '{field}' as Boolean failed.", msgPackException); } - private static int ReadInt32(byte[] input, ref int offset, string field) + private static int ReadInt32(ref MessagePackReader reader, string field) { - Exception msgPackException = null; try { - var readInt = MessagePackBinary.ReadInt32(input, offset, out var readSize); - offset += readSize; - return readInt; + return reader.ReadInt32(); } - catch (Exception e) + catch (Exception ex) { - msgPackException = e; + throw new InvalidDataException($"Reading '{field}' as Int32 failed.", ex); } - - throw new InvalidDataException($"Reading '{field}' as Int32 failed.", msgPackException); } - private static string ReadString(byte[] input, ref int offset, string field) + private static string ReadString(ref MessagePackReader reader, string field) { - Exception msgPackException = null; try { - var readString = MessagePackBinary.ReadString(input, offset, out var readSize); - offset += readSize; - return readString; + return reader.ReadString(); } - catch (Exception e) + catch (Exception ex) { - msgPackException = e; + throw new InvalidDataException($"Reading '{field}' as String failed.", ex); } - - throw new InvalidDataException($"Reading '{field}' as String failed.", msgPackException); } - private static long ReadMapLength(byte[] input, ref int offset, string field) + private static long ReadMapLength(ref MessagePackReader reader, string field) { - Exception msgPackException = null; try { - var readMap = MessagePackBinary.ReadMapHeader(input, offset, out var readSize); - offset += readSize; - return readMap; + return reader.ReadMapHeader(); } - catch (Exception e) + catch (Exception ex) { - msgPackException = e; + throw new InvalidDataException($"Reading map length for '{field}' failed.", ex); } - throw new InvalidDataException($"Reading map length for '{field}' failed.", msgPackException); } - private static long ReadArrayLength(byte[] input, ref int offset, string field) + private static long ReadArrayLength(ref MessagePackReader reader, string field) { - Exception msgPackException = null; try { - var readArray = MessagePackBinary.ReadArrayHeader(input, offset, out var readSize); - offset += readSize; - return readArray; + return reader.ReadArrayHeader(); } - catch (Exception e) + catch (Exception ex) { - msgPackException = e; + throw new InvalidDataException($"Reading array length for '{field}' failed.", ex); } - - throw new InvalidDataException($"Reading array length for '{field}' failed.", msgPackException); } - private static object DeserializeObject(byte[] input, ref int offset, Type type, string field, IFormatterResolver resolver) + private static object DeserializeObject(ref MessagePackReader reader, Type type, string field, MessagePackSerializerOptions msgPackSerializerOptions) { - Exception msgPackException = null; try { - var obj = MessagePackSerializer.NonGeneric.Deserialize(type, new ArraySegment(input, offset, input.Length - offset), resolver); - offset += MessagePackBinary.ReadNextBlock(input, offset); - return obj; + return MessagePackSerializer.Deserialize(type, ref reader, msgPackSerializerOptions); } catch (Exception ex) { - msgPackException = ex; + throw new InvalidDataException($"Deserializing object of the `{type.Name}` type for '{field}' failed.", ex); } - - throw new InvalidDataException($"Deserializing object of the `{type.Name}` type for '{field}' failed.", msgPackException); } internal static List CreateDefaultFormatterResolvers() @@ -703,10 +664,10 @@ internal class SignalRResolver : IFormatterResolver { public static readonly IFormatterResolver Instance = new SignalRResolver(); - public static readonly IList Resolvers = new[] + public static readonly IList Resolvers = new IFormatterResolver[] { - MessagePack.Resolvers.DynamicEnumAsStringResolver.Instance, - MessagePack.Resolvers.ContractlessStandardResolver.Instance, + DynamicEnumAsStringResolver.Instance, + ContractlessStandardResolver.Instance, }; public IMessagePackFormatter GetFormatter() @@ -731,30 +692,5 @@ static Cache() } } } - - // Support for users making their own Formatter lists - internal class CombinedResolvers : IFormatterResolver - { - private readonly IList _resolvers; - - public CombinedResolvers(IList resolvers) - { - _resolvers = resolvers; - } - - public IMessagePackFormatter GetFormatter() - { - foreach (var resolver in _resolvers) - { - var formatter = resolver.GetFormatter(); - if (formatter != null) - { - return formatter; - } - } - - return null; - } - } } } diff --git a/src/SignalR/common/SignalR.Common/test/Internal/Protocol/MessagePackHubProtocolTests.cs b/src/SignalR/common/SignalR.Common/test/Internal/Protocol/MessagePackHubProtocolTests.cs index 8a87a67bd695..2355228bd9c0 100644 --- a/src/SignalR/common/SignalR.Common/test/Internal/Protocol/MessagePackHubProtocolTests.cs +++ b/src/SignalR/common/SignalR.Common/test/Internal/Protocol/MessagePackHubProtocolTests.cs @@ -26,15 +26,20 @@ public void SerializerCanSerializeTypesWithNoDefaultCtor() AssertMessages(new byte[] { ArrayBytes(5), 3, 0x80, StringBytes(1), (byte)'0', 0x03, ArrayBytes(1), 42 }, result); } - [Fact] - public void WriteAndParseDateTimeConvertsToUTC() + [Theory] + [InlineData(DateTimeKind.Utc)] + [InlineData(DateTimeKind.Local)] + [InlineData(DateTimeKind.Unspecified)] + public void WriteAndParseDateTimeConvertsToUTC(DateTimeKind dateTimeKind) { - var dateTime = new DateTime(2018, 4, 9); + // The messagepack Timestamp format always converts input DateTime to Utc if they are passed as "DateTimeKind.Local" : + // https://github.com/neuecc/MessagePack-CSharp/pull/520/files#diff-ed970b3daebc708ce49f55d418075979 + var originalDateTime = new DateTime(2018, 4, 9, 0, 0, 0, dateTimeKind); var writer = MemoryBufferWriter.Get(); try { - HubProtocol.WriteMessage(CompletionMessage.WithResult("xyz", dateTime), writer); + HubProtocol.WriteMessage(CompletionMessage.WithResult("xyz", originalDateTime), writer); var bytes = new ReadOnlySequence(writer.ToArray()); HubProtocol.TryParseMessage(ref bytes, new TestBinder(typeof(DateTime)), out var hubMessage); @@ -44,7 +49,10 @@ public void WriteAndParseDateTimeConvertsToUTC() // The messagepack Timestamp format specifies that time is stored as seconds since 1970-01-01 UTC // so the library has no choice but to store the time as UTC // https://github.com/msgpack/msgpack/blob/master/spec.md#timestamp-extension-type - Assert.Equal(dateTime.ToUniversalTime(), resultDateTime); + // So If the original DateTiem was a "Local" one, we create a new DateTime equivalent to the original one but converted to Utc + var expectedUtcDateTime = (originalDateTime.Kind == DateTimeKind.Local) ? originalDateTime.ToUniversalTime() : originalDateTime; + + Assert.Equal(expectedUtcDateTime, resultDateTime); } finally { diff --git a/src/SignalR/server/SignalR/test/HubConnectionHandlerTests.cs b/src/SignalR/server/SignalR/test/HubConnectionHandlerTests.cs index 3d9c97c1ac19..4f2e9e3cd34c 100644 --- a/src/SignalR/server/SignalR/test/HubConnectionHandlerTests.cs +++ b/src/SignalR/server/SignalR/test/HubConnectionHandlerTests.cs @@ -2512,33 +2512,15 @@ public IMessagePackFormatter GetFormatter() private class StringFormatter : IMessagePackFormatter { - public T Deserialize(byte[] bytes, int offset, IFormatterResolver formatterResolver, out int readSize) + public T Deserialize(ref MessagePackReader reader, MessagePackSerializerOptions options) { // this method isn't used in our tests - readSize = 0; return default; } - public int Serialize(ref byte[] bytes, int offset, T value, IFormatterResolver formatterResolver) - { - // string of size 15 - bytes[offset] = 0xAF; - bytes[offset + 1] = (byte)'f'; - bytes[offset + 2] = (byte)'o'; - bytes[offset + 3] = (byte)'r'; - bytes[offset + 4] = (byte)'m'; - bytes[offset + 5] = (byte)'a'; - bytes[offset + 6] = (byte)'t'; - bytes[offset + 7] = (byte)'t'; - bytes[offset + 8] = (byte)'e'; - bytes[offset + 9] = (byte)'d'; - bytes[offset + 10] = (byte)'S'; - bytes[offset + 11] = (byte)'t'; - bytes[offset + 12] = (byte)'r'; - bytes[offset + 13] = (byte)'i'; - bytes[offset + 14] = (byte)'n'; - bytes[offset + 15] = (byte)'g'; - return 16; + public void Serialize(ref MessagePackWriter writer, T value, MessagePackSerializerOptions options) + { + writer.Write("formattedString"); } } } diff --git a/src/SignalR/server/StackExchangeRedis/src/Internal/MessagePackUtil.cs b/src/SignalR/server/StackExchangeRedis/src/Internal/MessagePackUtil.cs deleted file mode 100644 index 7780bca98850..000000000000 --- a/src/SignalR/server/StackExchangeRedis/src/Internal/MessagePackUtil.cs +++ /dev/null @@ -1,68 +0,0 @@ -// Copyright (c) .NET Foundation. All rights reserved. -// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. - -using System; -using System.Diagnostics; -using System.Runtime.InteropServices; -using MessagePack; - -namespace Microsoft.AspNetCore.SignalR.StackExchangeRedis.Internal -{ - internal static class MessagePackUtil - { - public static int ReadArrayHeader(ref ReadOnlyMemory data) - { - var arr = GetArray(data); - var val = MessagePackBinary.ReadArrayHeader(arr.Array, arr.Offset, out var readSize); - data = data.Slice(readSize); - return val; - } - - public static int ReadMapHeader(ref ReadOnlyMemory data) - { - var arr = GetArray(data); - var val = MessagePackBinary.ReadMapHeader(arr.Array, arr.Offset, out var readSize); - data = data.Slice(readSize); - return val; - } - - public static string ReadString(ref ReadOnlyMemory data) - { - var arr = GetArray(data); - var val = MessagePackBinary.ReadString(arr.Array, arr.Offset, out var readSize); - data = data.Slice(readSize); - return val; - } - - public static byte[] ReadBytes(ref ReadOnlyMemory data) - { - var arr = GetArray(data); - var val = MessagePackBinary.ReadBytes(arr.Array, arr.Offset, out var readSize); - data = data.Slice(readSize); - return val; - } - - public static int ReadInt32(ref ReadOnlyMemory data) - { - var arr = GetArray(data); - var val = MessagePackBinary.ReadInt32(arr.Array, arr.Offset, out var readSize); - data = data.Slice(readSize); - return val; - } - - public static byte ReadByte(ref ReadOnlyMemory data) - { - var arr = GetArray(data); - var val = MessagePackBinary.ReadByte(arr.Array, arr.Offset, out var readSize); - data = data.Slice(readSize); - return val; - } - - private static ArraySegment GetArray(ReadOnlyMemory data) - { - var isArray = MemoryMarshal.TryGetArray(data, out var array); - Debug.Assert(isArray); - return array; - } - } -} diff --git a/src/SignalR/server/StackExchangeRedis/src/Internal/RedisProtocol.cs b/src/SignalR/server/StackExchangeRedis/src/Internal/RedisProtocol.cs index 24426184895c..3126b52f5d49 100644 --- a/src/SignalR/server/StackExchangeRedis/src/Internal/RedisProtocol.cs +++ b/src/SignalR/server/StackExchangeRedis/src/Internal/RedisProtocol.cs @@ -2,6 +2,7 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; +using System.Buffers; using System.Collections.Generic; using System.Diagnostics; using System.IO; @@ -43,30 +44,33 @@ public byte[] WriteInvocation(string methodName, object[] args, IReadOnlyList 0) { - MessagePackBinary.WriteArrayHeader(writer, excludedConnectionIds.Count); + writer.WriteArrayHeader(excludedConnectionIds.Count); foreach (var id in excludedConnectionIds) { - MessagePackBinary.WriteString(writer, id); + writer.Write(id); } } else { - MessagePackBinary.WriteArrayHeader(writer, 0); + writer.WriteArrayHeader(0); } - WriteHubMessage(writer, new InvocationMessage(methodName, args)); - return writer.ToArray(); + WriteHubMessage(ref writer, new InvocationMessage(methodName, args)); + writer.Flush(); + + return memoryBufferWriter.ToArray(); } finally { - MemoryBufferWriter.Return(writer); + MemoryBufferWriter.Return(memoryBufferWriter); } } @@ -80,21 +84,24 @@ public byte[] WriteGroupCommand(RedisGroupCommand command) // * A 'str': The connection Id // Any additional items are discarded. - var writer = MemoryBufferWriter.Get(); + var memoryBufferWriter = MemoryBufferWriter.Get(); try { - MessagePackBinary.WriteArrayHeader(writer, 5); - MessagePackBinary.WriteInt32(writer, command.Id); - MessagePackBinary.WriteString(writer, command.ServerName); - MessagePackBinary.WriteByte(writer, (byte)command.Action); - MessagePackBinary.WriteString(writer, command.GroupName); - MessagePackBinary.WriteString(writer, command.ConnectionId); - - return writer.ToArray(); + var writer = new MessagePackWriter(memoryBufferWriter); + + writer.WriteArrayHeader(5); + writer.Write(command.Id); + writer.Write(command.ServerName); + writer.Write((byte)command.Action); + writer.Write(command.GroupName); + writer.Write(command.ConnectionId); + writer.Flush(); + + return memoryBufferWriter.ToArray(); } finally { - MemoryBufferWriter.Return(writer); + MemoryBufferWriter.Return(memoryBufferWriter); } } @@ -104,101 +111,110 @@ public byte[] WriteAck(int messageId) // * An 'int': The Id of the command being acknowledged. // Any additional items are discarded. - var writer = MemoryBufferWriter.Get(); + var memoryBufferWriter = MemoryBufferWriter.Get(); try { - MessagePackBinary.WriteArrayHeader(writer, 1); - MessagePackBinary.WriteInt32(writer, messageId); + var writer = new MessagePackWriter(memoryBufferWriter); - return writer.ToArray(); + writer.WriteArrayHeader(1); + writer.Write(messageId); + writer.Flush(); + + return memoryBufferWriter.ToArray(); } finally { - MemoryBufferWriter.Return(writer); + MemoryBufferWriter.Return(memoryBufferWriter); } } public RedisInvocation ReadInvocation(ReadOnlyMemory data) { // See WriteInvocation for the format - ValidateArraySize(ref data, 2, "Invocation"); + var reader = new MessagePackReader(data); + ValidateArraySize(ref reader, 2, "Invocation"); // Read excluded Ids IReadOnlyList excludedConnectionIds = null; - var idCount = MessagePackUtil.ReadArrayHeader(ref data); + var idCount = reader.ReadArrayHeader(); if (idCount > 0) { var ids = new string[idCount]; for (var i = 0; i < idCount; i++) { - ids[i] = MessagePackUtil.ReadString(ref data); + ids[i] = reader.ReadString(); } excludedConnectionIds = ids; } // Read payload - var message = ReadSerializedHubMessage(ref data); + var message = ReadSerializedHubMessage(ref reader); return new RedisInvocation(message, excludedConnectionIds); } public RedisGroupCommand ReadGroupCommand(ReadOnlyMemory data) { + var reader = new MessagePackReader(data); + // See WriteGroupCommand for format. - ValidateArraySize(ref data, 5, "GroupCommand"); + ValidateArraySize(ref reader, 5, "GroupCommand"); - var id = MessagePackUtil.ReadInt32(ref data); - var serverName = MessagePackUtil.ReadString(ref data); - var action = (GroupAction)MessagePackUtil.ReadByte(ref data); - var groupName = MessagePackUtil.ReadString(ref data); - var connectionId = MessagePackUtil.ReadString(ref data); + var id = reader.ReadInt32(); + var serverName = reader.ReadString(); + var action = (GroupAction)reader.ReadByte(); + var groupName = reader.ReadString(); + var connectionId = reader.ReadString(); return new RedisGroupCommand(id, serverName, action, groupName, connectionId); } public int ReadAck(ReadOnlyMemory data) { + var reader = new MessagePackReader(data); + // See WriteAck for format - ValidateArraySize(ref data, 1, "Ack"); - return MessagePackUtil.ReadInt32(ref data); + ValidateArraySize(ref reader, 1, "Ack"); + return reader.ReadInt32(); } - private void WriteHubMessage(Stream stream, HubMessage message) + private void WriteHubMessage(ref MessagePackWriter writer, HubMessage message) { // Written as a MessagePack 'map' where the keys are the name of the protocol (as a MessagePack 'str') // and the values are the serialized blob (as a MessagePack 'bin'). var serializedHubMessages = _messageSerializer.SerializeMessage(message); - MessagePackBinary.WriteMapHeader(stream, serializedHubMessages.Count); + writer.WriteMapHeader(serializedHubMessages.Count); foreach (var serializedMessage in serializedHubMessages) { - MessagePackBinary.WriteString(stream, serializedMessage.ProtocolName); + writer.Write(serializedMessage.ProtocolName); var isArray = MemoryMarshal.TryGetArray(serializedMessage.Serialized, out var array); Debug.Assert(isArray); - MessagePackBinary.WriteBytes(stream, array.Array, array.Offset, array.Count); + writer.Write(array); } } - public static SerializedHubMessage ReadSerializedHubMessage(ref ReadOnlyMemory data) + public static SerializedHubMessage ReadSerializedHubMessage(ref MessagePackReader reader) { - var count = MessagePackUtil.ReadMapHeader(ref data); + var count = reader.ReadMapHeader(); var serializations = new SerializedMessage[count]; for (var i = 0; i < count; i++) { - var protocol = MessagePackUtil.ReadString(ref data); - var serialized = MessagePackUtil.ReadBytes(ref data); + var protocol = reader.ReadString(); + var serialized = reader.ReadBytes()?.ToArray() ?? Array.Empty(); + serializations[i] = new SerializedMessage(protocol, serialized); } return new SerializedHubMessage(serializations); } - private static void ValidateArraySize(ref ReadOnlyMemory data, int expectedLength, string messageType) + private static void ValidateArraySize(ref MessagePackReader reader, int expectedLength, string messageType) { - var length = MessagePackUtil.ReadArrayHeader(ref data); + var length = reader.ReadArrayHeader(); if (length < expectedLength) {