From f078e6f9d2d92de019818427698b7aa746398c3e Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Mon, 24 Feb 2020 22:55:58 +0000 Subject: [PATCH] Sync changes from runtime --- .../runtime/Http2/Hpack/HPackDecoder.cs | 2 +- src/Shared/runtime/Http2/Hpack/HeaderField.cs | 7 +- .../runtime/Http3/QPack/QPackDecoder.cs | 8 + src/Shared/runtime/NetEventSource.Common.cs | 738 ++++++++++++ .../Implementations/Mock/MockConnection.cs | 226 ++++ .../Mock/MockImplementationProvider.cs | 21 + .../Quic/Implementations/Mock/MockListener.cs | 120 ++ .../Quic/Implementations/Mock/MockStream.cs | 259 ++++ .../MsQuic/Internal/MsQuicAddressHelpers.cs | 85 ++ .../MsQuic/Internal/MsQuicApi.cs | 361 ++++++ .../MsQuic/Internal/MsQuicParameterHelpers.cs | 98 ++ .../MsQuic/Internal/MsQuicSecurityConfig.cs | 45 + .../MsQuic/Internal/MsQuicSession.cs | 156 +++ .../MsQuic/Internal/QuicExceptionHelpers.cs | 17 + .../Internal/ResettableCompletionSource.cs | 81 ++ .../MsQuic/MsQuicConnection.cs | 416 +++++++ .../MsQuic/MsQuicImplementationProvider.cs | 22 + .../Implementations/MsQuic/MsQuicListener.cs | 213 ++++ .../Implementations/MsQuic/MsQuicStream.cs | 1042 +++++++++++++++++ .../Implementations/QuicConnectionProvider.cs | 36 + .../QuicImplementationProvider.cs | 17 + .../Implementations/QuicListenerProvider.cs | 22 + .../Implementations/QuicStreamProvider.cs | 53 + .../runtime/Quic/Interop/Interop.MsQuic.cs | 16 + .../runtime/Quic/Interop/MsQuicEnums.cs | 167 +++ .../Quic/Interop/MsQuicNativeMethods.cs | 488 ++++++++ .../runtime/Quic/Interop/MsQuicStatusCodes.cs | 121 ++ .../Quic/Interop/MsQuicStatusHelper.cs | 26 + .../runtime/Quic/NetEventSource.Quic.cs | 13 + .../Quic/QuicClientConnectionOptions.cs | 50 + src/Shared/runtime/Quic/QuicConnection.cs | 99 ++ .../Quic/QuicConnectionAbortedException.cs | 22 + src/Shared/runtime/Quic/QuicException.cs | 14 + .../Quic/QuicImplementationProviders.cs | 13 + src/Shared/runtime/Quic/QuicListener.cs | 55 + .../runtime/Quic/QuicListenerOptions.cs | 59 + .../Quic/QuicOperationAbortedException.cs | 18 + src/Shared/runtime/Quic/QuicStream.cs | 133 +++ .../Quic/QuicStreamAbortedException.cs | 22 + src/Shared/runtime/SR.Quic.cs | 20 + src/Shared/runtime/SR.resx | 12 + 41 files changed, 5387 insertions(+), 6 deletions(-) create mode 100644 src/Shared/runtime/NetEventSource.Common.cs create mode 100644 src/Shared/runtime/Quic/Implementations/Mock/MockConnection.cs create mode 100644 src/Shared/runtime/Quic/Implementations/Mock/MockImplementationProvider.cs create mode 100644 src/Shared/runtime/Quic/Implementations/Mock/MockListener.cs create mode 100644 src/Shared/runtime/Quic/Implementations/Mock/MockStream.cs create mode 100644 src/Shared/runtime/Quic/Implementations/MsQuic/Internal/MsQuicAddressHelpers.cs create mode 100644 src/Shared/runtime/Quic/Implementations/MsQuic/Internal/MsQuicApi.cs create mode 100644 src/Shared/runtime/Quic/Implementations/MsQuic/Internal/MsQuicParameterHelpers.cs create mode 100644 src/Shared/runtime/Quic/Implementations/MsQuic/Internal/MsQuicSecurityConfig.cs create mode 100644 src/Shared/runtime/Quic/Implementations/MsQuic/Internal/MsQuicSession.cs create mode 100644 src/Shared/runtime/Quic/Implementations/MsQuic/Internal/QuicExceptionHelpers.cs create mode 100644 src/Shared/runtime/Quic/Implementations/MsQuic/Internal/ResettableCompletionSource.cs create mode 100644 src/Shared/runtime/Quic/Implementations/MsQuic/MsQuicConnection.cs create mode 100644 src/Shared/runtime/Quic/Implementations/MsQuic/MsQuicImplementationProvider.cs create mode 100644 src/Shared/runtime/Quic/Implementations/MsQuic/MsQuicListener.cs create mode 100644 src/Shared/runtime/Quic/Implementations/MsQuic/MsQuicStream.cs create mode 100644 src/Shared/runtime/Quic/Implementations/QuicConnectionProvider.cs create mode 100644 src/Shared/runtime/Quic/Implementations/QuicImplementationProvider.cs create mode 100644 src/Shared/runtime/Quic/Implementations/QuicListenerProvider.cs create mode 100644 src/Shared/runtime/Quic/Implementations/QuicStreamProvider.cs create mode 100644 src/Shared/runtime/Quic/Interop/Interop.MsQuic.cs create mode 100644 src/Shared/runtime/Quic/Interop/MsQuicEnums.cs create mode 100644 src/Shared/runtime/Quic/Interop/MsQuicNativeMethods.cs create mode 100644 src/Shared/runtime/Quic/Interop/MsQuicStatusCodes.cs create mode 100644 src/Shared/runtime/Quic/Interop/MsQuicStatusHelper.cs create mode 100644 src/Shared/runtime/Quic/NetEventSource.Quic.cs create mode 100644 src/Shared/runtime/Quic/QuicClientConnectionOptions.cs create mode 100644 src/Shared/runtime/Quic/QuicConnection.cs create mode 100644 src/Shared/runtime/Quic/QuicConnectionAbortedException.cs create mode 100644 src/Shared/runtime/Quic/QuicException.cs create mode 100644 src/Shared/runtime/Quic/QuicImplementationProviders.cs create mode 100644 src/Shared/runtime/Quic/QuicListener.cs create mode 100644 src/Shared/runtime/Quic/QuicListenerOptions.cs create mode 100644 src/Shared/runtime/Quic/QuicOperationAbortedException.cs create mode 100644 src/Shared/runtime/Quic/QuicStream.cs create mode 100644 src/Shared/runtime/Quic/QuicStreamAbortedException.cs create mode 100644 src/Shared/runtime/SR.Quic.cs diff --git a/src/Shared/runtime/Http2/Hpack/HPackDecoder.cs b/src/Shared/runtime/Http2/Hpack/HPackDecoder.cs index 98fb41652376..997047f2c683 100644 --- a/src/Shared/runtime/Http2/Hpack/HPackDecoder.cs +++ b/src/Shared/runtime/Http2/Hpack/HPackDecoder.cs @@ -12,7 +12,7 @@ namespace System.Net.Http.HPack { internal class HPackDecoder { - private enum State + private enum State : byte { Ready, HeaderFieldIndex, diff --git a/src/Shared/runtime/Http2/Hpack/HeaderField.cs b/src/Shared/runtime/Http2/Hpack/HeaderField.cs index 1eba82412d6f..f8762046713a 100644 --- a/src/Shared/runtime/Http2/Hpack/HeaderField.cs +++ b/src/Shared/runtime/Http2/Hpack/HeaderField.cs @@ -21,11 +21,8 @@ public HeaderField(ReadOnlySpan name, ReadOnlySpan value) // We should revisit our allocation strategy here so we don't need to allocate per entry // and we have a cap to how much allocation can happen per dynamic table // (without limiting the number of table entries a server can provide within the table size limit). - Name = new byte[name.Length]; - name.CopyTo(Name); - - Value = new byte[value.Length]; - value.CopyTo(Value); + Name = name.ToArray(); + Value = value.ToArray(); } public byte[] Name { get; } diff --git a/src/Shared/runtime/Http3/QPack/QPackDecoder.cs b/src/Shared/runtime/Http3/QPack/QPackDecoder.cs index 6f63d66ce9df..958dfac303e8 100644 --- a/src/Shared/runtime/Http3/QPack/QPackDecoder.cs +++ b/src/Shared/runtime/Http3/QPack/QPackDecoder.cs @@ -269,6 +269,10 @@ private void OnByte(byte b, IHttpHeadersHandler handler) if (_integerDecoder.BeginTryDecode((byte)prefixInt, LiteralHeaderFieldWithoutNameReferencePrefix, out intResult)) { + if (intResult == 0) + { + throw new QPackDecodingException(SR.Format(SR.net_http_invalid_header_name, "")); + } OnStringLength(intResult, State.HeaderName); } else @@ -303,6 +307,10 @@ private void OnByte(byte b, IHttpHeadersHandler handler) case State.HeaderNameLength: if (_integerDecoder.TryDecode(b, out intResult)) { + if (intResult == 0) + { + throw new QPackDecodingException(SR.Format(SR.net_http_invalid_header_name, "")); + } OnStringLength(intResult, nextState: State.HeaderName); } break; diff --git a/src/Shared/runtime/NetEventSource.Common.cs b/src/Shared/runtime/NetEventSource.Common.cs new file mode 100644 index 000000000000..46cd2ee685c8 --- /dev/null +++ b/src/Shared/runtime/NetEventSource.Common.cs @@ -0,0 +1,738 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +#if DEBUG +// Uncomment to enable runtime checks to help validate that NetEventSource isn't being misused +// in a way that will cause performance problems, e.g. unexpected boxing of value types. +//#define DEBUG_NETEVENTSOURCE_MISUSE +#endif + +#nullable enable +using System.Collections; +using System.Diagnostics; +using System.Diagnostics.Tracing; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; +#if NET46 +using System.Security; +#endif + +#pragma warning disable CA1823 // not all IDs are used by all partial providers + +namespace System.Net +{ + // Implementation: + // This partial file is meant to be consumed into each System.Net.* assembly that needs to log. Each such assembly also provides + // its own NetEventSource partial class that adds an appropriate [EventSource] attribute, giving it a unique name for that assembly. + // Those partials can then also add additional events if needed, starting numbering from the NextAvailableEventId defined by this partial. + + // Usage: + // - Operations that may allocate (e.g. boxing a value type, using string interpolation, etc.) or that may have computations + // at call sites should guard access like: + // if (NetEventSource.IsEnabled) NetEventSource.Enter(this, refArg1, valueTypeArg2); // entering an instance method with a value type arg + // if (NetEventSource.IsEnabled) NetEventSource.Info(null, $"Found certificate: {cert}"); // info logging with a formattable string + // - Operations that have zero allocations / measurable computations at call sites can use a simpler pattern, calling methods like: + // NetEventSource.Enter(this); // entering an instance method + // NetEventSource.Info(this, "literal string"); // arbitrary message with a literal string + // NetEventSource.Enter(this, refArg1, regArg2); // entering an instance method with two reference type arguments + // NetEventSource.Enter(null); // entering a static method + // NetEventSource.Enter(null, refArg1); // entering a static method with one reference type argument + // Debug.Asserts inside the logging methods will help to flag some misuse if the DEBUG_NETEVENTSOURCE_MISUSE compilation constant is defined. + // However, because it can be difficult by observation to understand all of the costs involved, guarding can be done everywhere. + // - NetEventSource.Fail calls typically do not need to be prefixed with an IsEnabled check, even if they allocate, as FailMessage + // should only be used in cases similar to Debug.Fail, where they are not expected to happen in retail builds, and thus extra costs + // don't matter. + // - Messages can be strings, formattable strings, or any other object. Objects (including those used in formattable strings) have special + // formatting applied, controlled by the Format method. Partial specializations can also override this formatting by implementing a partial + // method that takes an object and optionally provides a string representation of it, in case a particular library wants to customize further. + + /// Provides logging facilities for System.Net libraries. +#if NET46 + [SecuritySafeCritical] +#endif + internal sealed partial class NetEventSource : EventSource + { + /// The single event source instance to use for all logging. + public static readonly NetEventSource Log = new NetEventSource(); + + #region Metadata + public class Keywords + { + public const EventKeywords Default = (EventKeywords)0x0001; + public const EventKeywords Debug = (EventKeywords)0x0002; + public const EventKeywords EnterExit = (EventKeywords)0x0004; + } + + private const string MissingMember = "(?)"; + private const string NullInstance = "(null)"; + private const string StaticMethodObject = "(static)"; + private const string NoParameters = ""; + private const int MaxDumpSize = 1024; + + private const int EnterEventId = 1; + private const int ExitEventId = 2; + private const int AssociateEventId = 3; + private const int InfoEventId = 4; + private const int ErrorEventId = 5; + private const int CriticalFailureEventId = 6; + private const int DumpArrayEventId = 7; + + // These events are implemented in NetEventSource.Security.cs. + // Define the ids here so that projects that include NetEventSource.Security.cs will not have conflicts. + private const int EnumerateSecurityPackagesId = 8; + private const int SspiPackageNotFoundId = 9; + private const int AcquireDefaultCredentialId = 10; + private const int AcquireCredentialsHandleId = 11; + private const int InitializeSecurityContextId = 12; + private const int SecurityContextInputBufferId = 13; + private const int SecurityContextInputBuffersId = 14; + private const int AcceptSecuritContextId = 15; + private const int OperationReturnedSomethingId = 16; + + private const int NextAvailableEventId = 17; // Update this value whenever new events are added. Derived types should base all events off of this to avoid conflicts. + #endregion + + #region Events + #region Enter + /// Logs entrance to a method. + /// `this`, or another object that serves to provide context for the operation. + /// A description of the entrance, including any arguments to the call. + /// The calling member. + [NonEvent] + public static void Enter(object? thisOrContextObject, FormattableString? formattableString = null, [CallerMemberName] string? memberName = null) + { + DebugValidateArg(thisOrContextObject); + DebugValidateArg(formattableString); + if (IsEnabled) Log.Enter(IdOf(thisOrContextObject), memberName, formattableString != null ? Format(formattableString) : NoParameters); + } + + /// Logs entrance to a method. + /// `this`, or another object that serves to provide context for the operation. + /// The object to log. + /// The calling member. + [NonEvent] + public static void Enter(object? thisOrContextObject, object arg0, [CallerMemberName] string? memberName = null) + { + DebugValidateArg(thisOrContextObject); + DebugValidateArg(arg0); + if (IsEnabled) Log.Enter(IdOf(thisOrContextObject), memberName, $"({Format(arg0)})"); + } + + /// Logs entrance to a method. + /// `this`, or another object that serves to provide context for the operation. + /// The first object to log. + /// The second object to log. + /// The calling member. + [NonEvent] + public static void Enter(object? thisOrContextObject, object arg0, object arg1, [CallerMemberName] string? memberName = null) + { + DebugValidateArg(thisOrContextObject); + DebugValidateArg(arg0); + DebugValidateArg(arg1); + if (IsEnabled) Log.Enter(IdOf(thisOrContextObject), memberName, $"({Format(arg0)}, {Format(arg1)})"); + } + + /// Logs entrance to a method. + /// `this`, or another object that serves to provide context for the operation. + /// The first object to log. + /// The second object to log. + /// The third object to log. + /// The calling member. + [NonEvent] + public static void Enter(object? thisOrContextObject, object arg0, object arg1, object arg2, [CallerMemberName] string? memberName = null) + { + DebugValidateArg(thisOrContextObject); + DebugValidateArg(arg0); + DebugValidateArg(arg1); + DebugValidateArg(arg2); + if (IsEnabled) Log.Enter(IdOf(thisOrContextObject), memberName, $"({Format(arg0)}, {Format(arg1)}, {Format(arg2)})"); + } + + [Event(EnterEventId, Level = EventLevel.Informational, Keywords = Keywords.EnterExit)] + private void Enter(string thisOrContextObject, string? memberName, string parameters) => + WriteEvent(EnterEventId, thisOrContextObject, memberName ?? MissingMember, parameters); + #endregion + + #region Exit + /// Logs exit from a method. + /// `this`, or another object that serves to provide context for the operation. + /// A description of the exit operation, including any return values. + /// The calling member. + [NonEvent] + public static void Exit(object? thisOrContextObject, FormattableString? formattableString = null, [CallerMemberName] string? memberName = null) + { + DebugValidateArg(thisOrContextObject); + DebugValidateArg(formattableString); + if (IsEnabled) Log.Exit(IdOf(thisOrContextObject), memberName, formattableString != null ? Format(formattableString) : NoParameters); + } + + /// Logs exit from a method. + /// `this`, or another object that serves to provide context for the operation. + /// A return value from the member. + /// The calling member. + [NonEvent] + public static void Exit(object? thisOrContextObject, object arg0, [CallerMemberName] string? memberName = null) + { + DebugValidateArg(thisOrContextObject); + DebugValidateArg(arg0); + if (IsEnabled) Log.Exit(IdOf(thisOrContextObject), memberName, Format(arg0).ToString()); + } + + /// Logs exit from a method. + /// `this`, or another object that serves to provide context for the operation. + /// A return value from the member. + /// A second return value from the member. + /// The calling member. + [NonEvent] + public static void Exit(object? thisOrContextObject, object arg0, object arg1, [CallerMemberName] string? memberName = null) + { + DebugValidateArg(thisOrContextObject); + DebugValidateArg(arg0); + DebugValidateArg(arg1); + if (IsEnabled) Log.Exit(IdOf(thisOrContextObject), memberName, $"{Format(arg0)}, {Format(arg1)}"); + } + + [Event(ExitEventId, Level = EventLevel.Informational, Keywords = Keywords.EnterExit)] + private void Exit(string thisOrContextObject, string? memberName, string? result) => + WriteEvent(ExitEventId, thisOrContextObject, memberName ?? MissingMember, result); + #endregion + + #region Info + /// Logs an information message. + /// `this`, or another object that serves to provide context for the operation. + /// The message to be logged. + /// The calling member. + [NonEvent] + public static void Info(object? thisOrContextObject, FormattableString? formattableString = null, [CallerMemberName] string? memberName = null) + { + DebugValidateArg(thisOrContextObject); + DebugValidateArg(formattableString); + if (IsEnabled) Log.Info(IdOf(thisOrContextObject), memberName, formattableString != null ? Format(formattableString) : NoParameters); + } + + /// Logs an information message. + /// `this`, or another object that serves to provide context for the operation. + /// The message to be logged. + /// The calling member. + [NonEvent] + public static void Info(object? thisOrContextObject, object? message, [CallerMemberName] string? memberName = null) + { + DebugValidateArg(thisOrContextObject); + DebugValidateArg(message); + if (IsEnabled) Log.Info(IdOf(thisOrContextObject), memberName, Format(message).ToString()); + } + + [Event(InfoEventId, Level = EventLevel.Informational, Keywords = Keywords.Default)] + private void Info(string thisOrContextObject, string? memberName, string? message) => + WriteEvent(InfoEventId, thisOrContextObject, memberName ?? MissingMember, message); + #endregion + + #region Error + /// Logs an error message. + /// `this`, or another object that serves to provide context for the operation. + /// The message to be logged. + /// The calling member. + [NonEvent] + public static void Error(object? thisOrContextObject, FormattableString formattableString, [CallerMemberName] string? memberName = null) + { + DebugValidateArg(thisOrContextObject); + DebugValidateArg(formattableString); + if (IsEnabled) Log.ErrorMessage(IdOf(thisOrContextObject), memberName, Format(formattableString)); + } + + /// Logs an error message. + /// `this`, or another object that serves to provide context for the operation. + /// The message to be logged. + /// The calling member. + [NonEvent] + public static void Error(object? thisOrContextObject, object message, [CallerMemberName] string? memberName = null) + { + DebugValidateArg(thisOrContextObject); + DebugValidateArg(message); + if (IsEnabled) Log.ErrorMessage(IdOf(thisOrContextObject), memberName, Format(message).ToString()); + } + + [Event(ErrorEventId, Level = EventLevel.Error, Keywords = Keywords.Default)] + private void ErrorMessage(string thisOrContextObject, string? memberName, string? message) => + WriteEvent(ErrorEventId, thisOrContextObject, memberName ?? MissingMember, message); + #endregion + + #region Fail + /// Logs a fatal error and raises an assert. + /// `this`, or another object that serves to provide context for the operation. + /// The message to be logged. + /// The calling member. + [NonEvent] + public static void Fail(object? thisOrContextObject, FormattableString formattableString, [CallerMemberName] string? memberName = null) + { + // Don't call DebugValidateArg on args, as we expect Fail to be used in assert/failure situations + // that should never happen in production, and thus we don't care about extra costs. + + if (IsEnabled) Log.CriticalFailure(IdOf(thisOrContextObject), memberName, Format(formattableString)); + Debug.Fail(Format(formattableString), $"{IdOf(thisOrContextObject)}.{memberName}"); + } + + /// Logs a fatal error and raises an assert. + /// `this`, or another object that serves to provide context for the operation. + /// The message to be logged. + /// The calling member. + [NonEvent] + public static void Fail(object? thisOrContextObject, object message, [CallerMemberName] string? memberName = null) + { + // Don't call DebugValidateArg on args, as we expect Fail to be used in assert/failure situations + // that should never happen in production, and thus we don't care about extra costs. + + if (IsEnabled) Log.CriticalFailure(IdOf(thisOrContextObject), memberName, Format(message).ToString()); + Debug.Fail(Format(message).ToString(), $"{IdOf(thisOrContextObject)}.{memberName}"); + } + + [Event(CriticalFailureEventId, Level = EventLevel.Critical, Keywords = Keywords.Debug)] + private void CriticalFailure(string thisOrContextObject, string? memberName, string? message) => + WriteEvent(CriticalFailureEventId, thisOrContextObject, memberName ?? MissingMember, message); + #endregion + + #region DumpBuffer + /// Logs the contents of a buffer. + /// `this`, or another object that serves to provide context for the operation. + /// The buffer to be logged. + /// The calling member. + [NonEvent] + public static void DumpBuffer(object? thisOrContextObject, byte[] buffer, [CallerMemberName] string? memberName = null) + { + DumpBuffer(thisOrContextObject, buffer, 0, buffer.Length, memberName); + } + + /// Logs the contents of a buffer. + /// `this`, or another object that serves to provide context for the operation. + /// The buffer to be logged. + /// The starting offset from which to log. + /// The number of bytes to log. + /// The calling member. + [NonEvent] + public static void DumpBuffer(object? thisOrContextObject, byte[] buffer, int offset, int count, [CallerMemberName] string? memberName = null) + { + if (IsEnabled) + { + if (offset < 0 || offset > buffer.Length - count) + { + Fail(thisOrContextObject, $"Invalid {nameof(DumpBuffer)} Args. Length={buffer.Length}, Offset={offset}, Count={count}", memberName); + return; + } + + count = Math.Min(count, MaxDumpSize); + + byte[] slice = buffer; + if (offset != 0 || count != buffer.Length) + { + slice = new byte[count]; + Buffer.BlockCopy(buffer, offset, slice, 0, count); + } + + Log.DumpBuffer(IdOf(thisOrContextObject), memberName, slice); + } + } + + /// Logs the contents of a buffer. + /// `this`, or another object that serves to provide context for the operation. + /// The starting location of the buffer to be logged. + /// The number of bytes to log. + /// The calling member. + [NonEvent] + public static unsafe void DumpBuffer(object? thisOrContextObject, IntPtr bufferPtr, int count, [CallerMemberName] string? memberName = null) + { + Debug.Assert(bufferPtr != IntPtr.Zero); + Debug.Assert(count >= 0); + + if (IsEnabled) + { + var buffer = new byte[Math.Min(count, MaxDumpSize)]; + fixed (byte* targetPtr = buffer) + { + Buffer.MemoryCopy((byte*)bufferPtr, targetPtr, buffer.Length, buffer.Length); + } + Log.DumpBuffer(IdOf(thisOrContextObject), memberName, buffer); + } + } + + [Event(DumpArrayEventId, Level = EventLevel.Verbose, Keywords = Keywords.Debug)] + private unsafe void DumpBuffer(string thisOrContextObject, string? memberName, byte[] buffer) => + WriteEvent(DumpArrayEventId, thisOrContextObject, memberName ?? MissingMember, buffer); + #endregion + + #region Associate + /// Logs a relationship between two objects. + /// The first object. + /// The second object. + /// The calling member. + [NonEvent] + public static void Associate(object first, object second, [CallerMemberName] string? memberName = null) + { + DebugValidateArg(first); + DebugValidateArg(second); + if (IsEnabled) Log.Associate(IdOf(first), memberName, IdOf(first), IdOf(second)); + } + + /// Logs a relationship between two objects. + /// `this`, or another object that serves to provide context for the operation. + /// The first object. + /// The second object. + /// The calling member. + [NonEvent] + public static void Associate(object? thisOrContextObject, object first, object second, [CallerMemberName] string? memberName = null) + { + DebugValidateArg(thisOrContextObject); + DebugValidateArg(first); + DebugValidateArg(second); + if (IsEnabled) Log.Associate(IdOf(thisOrContextObject), memberName, IdOf(first), IdOf(second)); + } + + [Event(AssociateEventId, Level = EventLevel.Informational, Keywords = Keywords.Default, Message = "[{2}]<-->[{3}]")] + private void Associate(string thisOrContextObject, string? memberName, string first, string second) => + WriteEvent(AssociateEventId, thisOrContextObject, memberName ?? MissingMember, first, second); + #endregion + #endregion + + #region Helpers + [Conditional("DEBUG_NETEVENTSOURCE_MISUSE")] + private static void DebugValidateArg(object? arg) + { + if (!IsEnabled) + { + Debug.Assert(!(arg is ValueType), $"Should not be passing value type {arg?.GetType()} to logging without IsEnabled check"); + Debug.Assert(!(arg is FormattableString), $"Should not be formatting FormattableString \"{arg}\" if tracing isn't enabled"); + } + } + + [Conditional("DEBUG_NETEVENTSOURCE_MISUSE")] + private static void DebugValidateArg(FormattableString? arg) + { + Debug.Assert(IsEnabled || arg == null, $"Should not be formatting FormattableString \"{arg}\" if tracing isn't enabled"); + } + + public static new bool IsEnabled => + Log.IsEnabled(); + + [NonEvent] + public static string IdOf(object? value) => value != null ? value.GetType().Name + "#" + GetHashCode(value) : NullInstance; + + [NonEvent] + public static int GetHashCode(object value) => value?.GetHashCode() ?? 0; + + [NonEvent] + public static object Format(object? value) + { + // If it's null, return a known string for null values + if (value == null) + { + return NullInstance; + } + + // Give another partial implementation a chance to provide its own string representation + string? result = null; + AdditionalCustomizedToString(value, ref result); + if (result != null) + { + return result; + } + + // Format arrays with their element type name and length + if (value is Array arr) + { + return $"{arr.GetType().GetElementType()}[{((Array)value).Length}]"; + } + + // Format ICollections as the name and count + if (value is ICollection c) + { + return $"{c.GetType().Name}({c.Count})"; + } + + // Format SafeHandles as their type, hash code, and pointer value + if (value is SafeHandle handle) + { + return $"{handle.GetType().Name}:{handle.GetHashCode()}(0x{handle.DangerousGetHandle():X})"; + } + + // Format IntPtrs as hex + if (value is IntPtr) + { + return $"0x{value:X}"; + } + + // If the string representation of the instance would just be its type name, + // use its id instead. + string? toString = value.ToString(); + if (toString == null || toString == value.GetType().FullName) + { + return IdOf(value); + } + + // Otherwise, return the original object so that the caller does default formatting. + return value; + } + + [NonEvent] + private static string Format(FormattableString s) + { + switch (s.ArgumentCount) + { + case 0: return s.Format; + case 1: return string.Format(s.Format, Format(s.GetArgument(0))); + case 2: return string.Format(s.Format, Format(s.GetArgument(0)), Format(s.GetArgument(1))); + case 3: return string.Format(s.Format, Format(s.GetArgument(0)), Format(s.GetArgument(1)), Format(s.GetArgument(2))); + default: + object?[] args = s.GetArguments(); + object[] formattedArgs = new object[args.Length]; + for (int i = 0; i < args.Length; i++) + { + formattedArgs[i] = Format(args[i]); + } + return string.Format(s.Format, formattedArgs); + } + } + + static partial void AdditionalCustomizedToString(T value, ref string? result); + #endregion + + #region Custom WriteEvent overloads + + [NonEvent] + private unsafe void WriteEvent(int eventId, string? arg1, string? arg2, string? arg3, string? arg4) + { + if (IsEnabled()) + { + if (arg1 == null) arg1 = ""; + if (arg2 == null) arg2 = ""; + if (arg3 == null) arg3 = ""; + if (arg4 == null) arg4 = ""; + + fixed (char* string1Bytes = arg1) + fixed (char* string2Bytes = arg2) + fixed (char* string3Bytes = arg3) + fixed (char* string4Bytes = arg4) + { + const int NumEventDatas = 4; + var descrs = stackalloc EventData[NumEventDatas]; + + descrs[0] = new EventData + { + DataPointer = (IntPtr)string1Bytes, + Size = ((arg1.Length + 1) * 2) + }; + descrs[1] = new EventData + { + DataPointer = (IntPtr)string2Bytes, + Size = ((arg2.Length + 1) * 2) + }; + descrs[2] = new EventData + { + DataPointer = (IntPtr)string3Bytes, + Size = ((arg3.Length + 1) * 2) + }; + descrs[3] = new EventData + { + DataPointer = (IntPtr)string4Bytes, + Size = ((arg4.Length + 1) * 2) + }; + + WriteEventCore(eventId, NumEventDatas, descrs); + } + } + } + + [NonEvent] + private unsafe void WriteEvent(int eventId, string? arg1, string? arg2, byte[]? arg3) + { + if (IsEnabled()) + { + if (arg1 == null) arg1 = ""; + if (arg2 == null) arg2 = ""; + if (arg3 == null) arg3 = Array.Empty(); + + fixed (char* arg1Ptr = arg1) + fixed (char* arg2Ptr = arg2) + fixed (byte* arg3Ptr = arg3) + { + int bufferLength = arg3.Length; + const int NumEventDatas = 4; + var descrs = stackalloc EventData[NumEventDatas]; + + descrs[0] = new EventData + { + DataPointer = (IntPtr)arg1Ptr, + Size = (arg1.Length + 1) * sizeof(char) + }; + descrs[1] = new EventData + { + DataPointer = (IntPtr)arg2Ptr, + Size = (arg2.Length + 1) * sizeof(char) + }; + descrs[2] = new EventData + { + DataPointer = (IntPtr)(&bufferLength), + Size = 4 + }; + descrs[3] = new EventData + { + DataPointer = (IntPtr)arg3Ptr, + Size = bufferLength + }; + + WriteEventCore(eventId, NumEventDatas, descrs); + } + } + } + + [NonEvent] + private unsafe void WriteEvent(int eventId, string? arg1, int arg2, int arg3, int arg4) + { + if (IsEnabled()) + { + if (arg1 == null) arg1 = ""; + + fixed (char* arg1Ptr = arg1) + { + const int NumEventDatas = 4; + var descrs = stackalloc EventData[NumEventDatas]; + + descrs[0] = new EventData + { + DataPointer = (IntPtr)(arg1Ptr), + Size = (arg1.Length + 1) * sizeof(char) + }; + descrs[1] = new EventData + { + DataPointer = (IntPtr)(&arg2), + Size = sizeof(int) + }; + descrs[2] = new EventData + { + DataPointer = (IntPtr)(&arg3), + Size = sizeof(int) + }; + descrs[3] = new EventData + { + DataPointer = (IntPtr)(&arg4), + Size = sizeof(int) + }; + + WriteEventCore(eventId, NumEventDatas, descrs); + } + } + } + + [NonEvent] + private unsafe void WriteEvent(int eventId, string? arg1, int arg2, string? arg3) + { + if (IsEnabled()) + { + if (arg1 == null) arg1 = ""; + if (arg3 == null) arg3 = ""; + + fixed (char* arg1Ptr = arg1) + fixed (char* arg3Ptr = arg3) + { + const int NumEventDatas = 3; + var descrs = stackalloc EventData[NumEventDatas]; + + descrs[0] = new EventData + { + DataPointer = (IntPtr)(arg1Ptr), + Size = (arg1.Length + 1) * sizeof(char) + }; + descrs[1] = new EventData + { + DataPointer = (IntPtr)(&arg2), + Size = sizeof(int) + }; + descrs[2] = new EventData + { + DataPointer = (IntPtr)(arg3Ptr), + Size = (arg3.Length + 1) * sizeof(char) + }; + + WriteEventCore(eventId, NumEventDatas, descrs); + } + } + } + + [NonEvent] + private unsafe void WriteEvent(int eventId, string? arg1, string? arg2, int arg3) + { + if (IsEnabled()) + { + if (arg1 == null) arg1 = ""; + if (arg2 == null) arg2 = ""; + + fixed (char* arg1Ptr = arg1) + fixed (char* arg2Ptr = arg2) + { + const int NumEventDatas = 3; + var descrs = stackalloc EventData[NumEventDatas]; + + descrs[0] = new EventData + { + DataPointer = (IntPtr)(arg1Ptr), + Size = (arg1.Length + 1) * sizeof(char) + }; + descrs[1] = new EventData + { + DataPointer = (IntPtr)(arg2Ptr), + Size = (arg2.Length + 1) * sizeof(char) + }; + descrs[2] = new EventData + { + DataPointer = (IntPtr)(&arg3), + Size = sizeof(int) + }; + + WriteEventCore(eventId, NumEventDatas, descrs); + } + } + } + + [NonEvent] + private unsafe void WriteEvent(int eventId, string? arg1, string? arg2, string? arg3, int arg4) + { + if (IsEnabled()) + { + if (arg1 == null) arg1 = ""; + if (arg2 == null) arg2 = ""; + if (arg3 == null) arg3 = ""; + + fixed (char* arg1Ptr = arg1) + fixed (char* arg2Ptr = arg2) + fixed (char* arg3Ptr = arg3) + { + const int NumEventDatas = 4; + var descrs = stackalloc EventData[NumEventDatas]; + + descrs[0] = new EventData + { + DataPointer = (IntPtr)(arg1Ptr), + Size = (arg1.Length + 1) * sizeof(char) + }; + descrs[1] = new EventData + { + DataPointer = (IntPtr)(arg2Ptr), + Size = (arg2.Length + 1) * sizeof(char) + }; + descrs[2] = new EventData + { + DataPointer = (IntPtr)(arg3Ptr), + Size = (arg3.Length + 1) * sizeof(char) + }; + descrs[3] = new EventData + { + DataPointer = (IntPtr)(&arg4), + Size = sizeof(int) + }; + + WriteEventCore(eventId, NumEventDatas, descrs); + } + } + } + #endregion + } +} diff --git a/src/Shared/runtime/Quic/Implementations/Mock/MockConnection.cs b/src/Shared/runtime/Quic/Implementations/Mock/MockConnection.cs new file mode 100644 index 000000000000..d9ab07022e72 --- /dev/null +++ b/src/Shared/runtime/Quic/Implementations/Mock/MockConnection.cs @@ -0,0 +1,226 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Buffers.Binary; +using System.Net.Security; +using System.Net.Sockets; +using System.Threading; +using System.Threading.Tasks; + +namespace System.Net.Quic.Implementations.Mock +{ + internal sealed class MockConnection : QuicConnectionProvider + { + private readonly bool _isClient; + private bool _disposed = false; + private IPEndPoint _remoteEndPoint; + private IPEndPoint _localEndPoint; + private object _syncObject = new object(); + private Socket _socket = null; + private IPEndPoint _peerListenEndPoint = null; + private TcpListener _inboundListener = null; + private long _nextOutboundBidirectionalStream; + private long _nextOutboundUnidirectionalStream; + + // Constructor for outbound connections + internal MockConnection(IPEndPoint remoteEndPoint, SslClientAuthenticationOptions sslClientAuthenticationOptions, IPEndPoint localEndPoint = null) + { + _remoteEndPoint = remoteEndPoint; + _localEndPoint = localEndPoint; + + _isClient = true; + _nextOutboundBidirectionalStream = 0; + _nextOutboundUnidirectionalStream = 2; + } + + // Constructor for accepted inbound connections + internal MockConnection(Socket socket, IPEndPoint peerListenEndPoint, TcpListener inboundListener) + { + _isClient = false; + _nextOutboundBidirectionalStream = 1; + _nextOutboundUnidirectionalStream = 3; + _socket = socket; + _peerListenEndPoint = peerListenEndPoint; + _inboundListener = inboundListener; + _localEndPoint = (IPEndPoint)socket.LocalEndPoint; + _remoteEndPoint = (IPEndPoint)socket.RemoteEndPoint; + } + + internal override bool Connected + { + get + { + CheckDisposed(); + + return _socket != null; + } + } + + internal override IPEndPoint LocalEndPoint => new IPEndPoint(_localEndPoint.Address, _localEndPoint.Port); + + internal override IPEndPoint RemoteEndPoint => new IPEndPoint(_remoteEndPoint.Address, _remoteEndPoint.Port); + + internal override SslApplicationProtocol NegotiatedApplicationProtocol => throw new NotImplementedException(); + + internal override async ValueTask ConnectAsync(CancellationToken cancellationToken = default) + { + CheckDisposed(); + + if (Connected) + { + // TODO: Exception text + throw new InvalidOperationException("Already connected"); + } + + Socket socket = new Socket(_remoteEndPoint.AddressFamily, SocketType.Stream, ProtocolType.Tcp); + await socket.ConnectAsync(_remoteEndPoint).ConfigureAwait(false); + socket.NoDelay = true; + + _localEndPoint = (IPEndPoint)socket.LocalEndPoint; + + // Listen on a new local endpoint for inbound streams + TcpListener inboundListener = new TcpListener(_localEndPoint.Address, 0); + inboundListener.Start(); + int inboundListenPort = ((IPEndPoint)inboundListener.LocalEndpoint).Port; + + // Write inbound listen port to socket so server can read it + byte[] buffer = new byte[4]; + BinaryPrimitives.WriteInt32LittleEndian(buffer, inboundListenPort); + await socket.SendAsync(buffer, SocketFlags.None).ConfigureAwait(false); + + // Read first 4 bytes to get server listen port + int bytesRead = 0; + do + { + bytesRead += await socket.ReceiveAsync(buffer.AsMemory().Slice(bytesRead), SocketFlags.None).ConfigureAwait(false); + } while (bytesRead != buffer.Length); + + int peerListenPort = BinaryPrimitives.ReadInt32LittleEndian(buffer); + IPEndPoint peerListenEndPoint = new IPEndPoint(((IPEndPoint)socket.RemoteEndPoint).Address, peerListenPort); + + _socket = socket; + _peerListenEndPoint = peerListenEndPoint; + _inboundListener = inboundListener; + } + + internal override QuicStreamProvider OpenUnidirectionalStream() + { + long streamId; + lock (_syncObject) + { + streamId = _nextOutboundUnidirectionalStream; + _nextOutboundUnidirectionalStream += 4; + } + + return new MockStream(this, streamId, bidirectional: false); + } + + internal override QuicStreamProvider OpenBidirectionalStream() + { + long streamId; + lock (_syncObject) + { + streamId = _nextOutboundBidirectionalStream; + _nextOutboundBidirectionalStream += 4; + } + + return new MockStream(this, streamId, bidirectional: true); + } + + internal override long GetRemoteAvailableUnidirectionalStreamCount() + { + throw new NotImplementedException(); + } + + internal override long GetRemoteAvailableBidirectionalStreamCount() + { + throw new NotImplementedException(); + } + + internal async Task CreateOutboundMockStreamAsync(long streamId) + { + Socket socket = new Socket(SocketType.Stream, ProtocolType.Tcp); + await socket.ConnectAsync(_peerListenEndPoint).ConfigureAwait(false); + socket.NoDelay = true; + + // Write stream ID to socket so server can read it + byte[] buffer = new byte[8]; + BinaryPrimitives.WriteInt64LittleEndian(buffer, streamId); + await socket.SendAsync(buffer, SocketFlags.None).ConfigureAwait(false); + + return socket; + } + + internal override async ValueTask AcceptStreamAsync(CancellationToken cancellationToken = default) + { + CheckDisposed(); + + Socket socket = await _inboundListener.AcceptSocketAsync().ConfigureAwait(false); + + // Read first bytes to get stream ID + byte[] buffer = new byte[8]; + int bytesRead = 0; + do + { + bytesRead += await socket.ReceiveAsync(buffer.AsMemory().Slice(bytesRead), SocketFlags.None).ConfigureAwait(false); + } while (bytesRead != buffer.Length); + + long streamId = BinaryPrimitives.ReadInt64LittleEndian(buffer); + + bool clientInitiated = ((streamId & 0b01) == 0); + if (clientInitiated == _isClient) + { + throw new Exception($"Wrong initiator on accepted stream??? streamId={streamId}, _isClient={_isClient}"); + } + + bool bidirectional = ((streamId & 0b10) == 0); + return new MockStream(socket, streamId, bidirectional: bidirectional); + } + + internal override ValueTask CloseAsync(long errorCode, CancellationToken cancellationToken = default) + { + Dispose(); + return default; + } + + private void CheckDisposed() + { + if (_disposed) + { + throw new ObjectDisposedException(nameof(QuicConnection)); + } + } + + private void Dispose(bool disposing) + { + if (!_disposed) + { + if (disposing) + { + _socket?.Dispose(); + _socket = null; + + _inboundListener?.Stop(); + _inboundListener = null; + } + + // TODO: free unmanaged resources (unmanaged objects) and override a finalizer below. + // TODO: set large fields to null. + + _disposed = true; + } + } + + ~MockConnection() + { + Dispose(false); + } + + public override void Dispose() + { + Dispose(true); + GC.SuppressFinalize(this); + } + } +} diff --git a/src/Shared/runtime/Quic/Implementations/Mock/MockImplementationProvider.cs b/src/Shared/runtime/Quic/Implementations/Mock/MockImplementationProvider.cs new file mode 100644 index 000000000000..b70a11328483 --- /dev/null +++ b/src/Shared/runtime/Quic/Implementations/Mock/MockImplementationProvider.cs @@ -0,0 +1,21 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Net.Security; + +namespace System.Net.Quic.Implementations.Mock +{ + internal sealed class MockImplementationProvider : QuicImplementationProvider + { + internal override QuicListenerProvider CreateListener(QuicListenerOptions options) + { + return new MockListener(options.ListenEndPoint, options.ServerAuthenticationOptions); + } + + internal override QuicConnectionProvider CreateConnection(QuicClientConnectionOptions options) + { + return new MockConnection(options.RemoteEndPoint, options.ClientAuthenticationOptions, options.LocalEndPoint); + } + } +} diff --git a/src/Shared/runtime/Quic/Implementations/Mock/MockListener.cs b/src/Shared/runtime/Quic/Implementations/Mock/MockListener.cs new file mode 100644 index 000000000000..911f1896b169 --- /dev/null +++ b/src/Shared/runtime/Quic/Implementations/Mock/MockListener.cs @@ -0,0 +1,120 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Net.Sockets; +using System.Net.Security; +using System.Threading.Tasks; +using System.Threading; +using System.Buffers.Binary; + +namespace System.Net.Quic.Implementations.Mock +{ + internal sealed class MockListener : QuicListenerProvider + { + private bool _disposed = false; + private SslServerAuthenticationOptions _sslOptions; + private IPEndPoint _listenEndPoint; + private TcpListener _tcpListener = null; + + internal MockListener(IPEndPoint listenEndPoint, SslServerAuthenticationOptions sslServerAuthenticationOptions) + { + if (listenEndPoint == null) + { + throw new ArgumentNullException(nameof(listenEndPoint)); + } + + _sslOptions = sslServerAuthenticationOptions; + _listenEndPoint = listenEndPoint; + + _tcpListener = new TcpListener(listenEndPoint); + } + + // IPEndPoint is mutable, so we must create a new instance every time this is retrieved. + internal override IPEndPoint ListenEndPoint => new IPEndPoint(_listenEndPoint.Address, _listenEndPoint.Port); + + internal override async ValueTask AcceptConnectionAsync(CancellationToken cancellationToken = default) + { + CheckDisposed(); + + Socket socket = await _tcpListener.AcceptSocketAsync().ConfigureAwait(false); + socket.NoDelay = true; + + // Read first 4 bytes to get client listen port + byte[] buffer = new byte[4]; + int bytesRead = 0; + do + { + bytesRead += await socket.ReceiveAsync(buffer.AsMemory().Slice(bytesRead), SocketFlags.None).ConfigureAwait(false); + } while (bytesRead != buffer.Length); + + int peerListenPort = BinaryPrimitives.ReadInt32LittleEndian(buffer); + IPEndPoint peerListenEndPoint = new IPEndPoint(((IPEndPoint)socket.RemoteEndPoint).Address, peerListenPort); + + // Listen on a new local endpoint for inbound streams + TcpListener inboundListener = new TcpListener(_listenEndPoint.Address, 0); + inboundListener.Start(); + int inboundListenPort = ((IPEndPoint)inboundListener.LocalEndpoint).Port; + + // Write inbound listen port to socket so client can read it + BinaryPrimitives.WriteInt32LittleEndian(buffer, inboundListenPort); + await socket.SendAsync(buffer, SocketFlags.None).ConfigureAwait(false); + + return new MockConnection(socket, peerListenEndPoint, inboundListener); + } + + internal override void Start() + { + CheckDisposed(); + + _tcpListener.Start(); + + if (_listenEndPoint.Port == 0) + { + // Get auto-assigned port + _listenEndPoint = (IPEndPoint)_tcpListener.LocalEndpoint; + } + } + + internal override void Close() + { + Dispose(); + } + + private void CheckDisposed() + { + if (_disposed) + { + throw new ObjectDisposedException(nameof(QuicListener)); + } + } + + private void Dispose(bool disposing) + { + if (!_disposed) + { + if (disposing) + { + _tcpListener?.Stop(); + _tcpListener = null; + } + + // TODO: free unmanaged resources (unmanaged objects) and override a finalizer below. + // TODO: set large fields to null. + + _disposed = true; + } + } + + ~MockListener() + { + Dispose(false); + } + + public override void Dispose() + { + Dispose(true); + GC.SuppressFinalize(this); + } + } +} diff --git a/src/Shared/runtime/Quic/Implementations/Mock/MockStream.cs b/src/Shared/runtime/Quic/Implementations/Mock/MockStream.cs new file mode 100644 index 000000000000..187ba680e17d --- /dev/null +++ b/src/Shared/runtime/Quic/Implementations/Mock/MockStream.cs @@ -0,0 +1,259 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Buffers; +using System.Diagnostics; +using System.Net.Sockets; +using System.Threading; +using System.Threading.Tasks; + +namespace System.Net.Quic.Implementations.Mock +{ + internal sealed class MockStream : QuicStreamProvider + { + private bool _disposed = false; + private readonly long _streamId; + private bool _canRead; + private bool _canWrite; + + private MockConnection _connection; + + private Socket _socket = null; + + // Constructor for outbound streams + internal MockStream(MockConnection connection, long streamId, bool bidirectional) + { + _connection = connection; + _streamId = streamId; + _canRead = bidirectional; + _canWrite = true; + } + + // Constructor for inbound streams + internal MockStream(Socket socket, long streamId, bool bidirectional) + { + _socket = socket; + _streamId = streamId; + _canRead = true; + _canWrite = bidirectional; + } + + private async ValueTask ConnectAsync(CancellationToken cancellationToken = default) + { + Debug.Assert(_connection != null, "Stream not connected but no connection???"); + + _socket = await _connection.CreateOutboundMockStreamAsync(_streamId).ConfigureAwait(false); + + // Don't need to hold on to the connection any longer. + _connection = null; + } + + internal override long StreamId + { + get + { + CheckDisposed(); + return _streamId; + } + } + + internal override bool CanRead => _canRead; + + internal override int Read(Span buffer) + { + CheckDisposed(); + + if (!_canRead) + { + throw new NotSupportedException(); + } + + return _socket.Receive(buffer); + } + + internal override async ValueTask ReadAsync(Memory buffer, CancellationToken cancellationToken = default) + { + CheckDisposed(); + + if (!_canRead) + { + throw new NotSupportedException(); + } + + if (_socket == null) + { + await ConnectAsync(cancellationToken).ConfigureAwait(false); + } + + return await _socket.ReceiveAsync(buffer, SocketFlags.None, cancellationToken).ConfigureAwait(false); + } + + internal override bool CanWrite => _canWrite; + + internal override void Write(ReadOnlySpan buffer) + { + CheckDisposed(); + + if (!_canWrite) + { + throw new NotSupportedException(); + } + + _socket.Send(buffer); + } + + internal override ValueTask WriteAsync(ReadOnlyMemory buffer, CancellationToken cancellationToken = default) + { + return WriteAsync(buffer, endStream: false, cancellationToken); + } + + internal override async ValueTask WriteAsync(ReadOnlyMemory buffer, bool endStream, CancellationToken cancellationToken = default) + { + CheckDisposed(); + + if (!_canWrite) + { + throw new NotSupportedException(); + } + + if (_socket == null) + { + await ConnectAsync(cancellationToken).ConfigureAwait(false); + } + await _socket.SendAsync(buffer, SocketFlags.None, cancellationToken).ConfigureAwait(false); + + if (endStream) + { + _socket.Shutdown(SocketShutdown.Send); + } + } + + internal override ValueTask WriteAsync(ReadOnlySequence buffers, CancellationToken cancellationToken = default) + { + return WriteAsync(buffers, endStream: false, cancellationToken); + } + internal override async ValueTask WriteAsync(ReadOnlySequence buffers, bool endStream, CancellationToken cancellationToken = default) + { + CheckDisposed(); + + if (!_canWrite) + { + throw new NotSupportedException(); + } + + if (_socket == null) + { + await ConnectAsync(cancellationToken).ConfigureAwait(false); + } + + foreach (ReadOnlyMemory buffer in buffers) + { + await _socket.SendAsync(buffer, SocketFlags.None, cancellationToken).ConfigureAwait(false); + } + + if (endStream) + { + _socket.Shutdown(SocketShutdown.Send); + } + } + + internal override ValueTask WriteAsync(ReadOnlyMemory> buffers, CancellationToken cancellationToken = default) + { + return WriteAsync(buffers, endStream: false, cancellationToken); + } + internal override async ValueTask WriteAsync(ReadOnlyMemory> buffers, bool endStream, CancellationToken cancellationToken = default) + { + CheckDisposed(); + + if (!_canWrite) + { + throw new NotSupportedException(); + } + + if (_socket == null) + { + await ConnectAsync(cancellationToken).ConfigureAwait(false); + } + + foreach (ReadOnlyMemory buffer in buffers.ToArray()) + { + await _socket.SendAsync(buffer, SocketFlags.None, cancellationToken).ConfigureAwait(false); + } + + if (endStream) + { + _socket.Shutdown(SocketShutdown.Send); + } + } + + internal override void Flush() + { + CheckDisposed(); + } + + internal override Task FlushAsync(CancellationToken cancellationToken) + { + CheckDisposed(); + + return Task.CompletedTask; + } + + internal override void AbortRead(long errorCode) + { + throw new NotImplementedException(); + } + + internal override void AbortWrite(long errorCode) + { + throw new NotImplementedException(); + } + + + internal override ValueTask ShutdownWriteCompleted(CancellationToken cancellationToken = default) + { + CheckDisposed(); + + return default; + } + + internal override void Shutdown() + { + CheckDisposed(); + + _socket.Shutdown(SocketShutdown.Send); + } + + private void CheckDisposed() + { + if (_disposed) + { + throw new ObjectDisposedException(nameof(QuicStream)); + } + } + + public override void Dispose() + { + if (!_disposed) + { + _disposed = true; + + _socket?.Dispose(); + _socket = null; + } + } + + public override ValueTask DisposeAsync() + { + if (!_disposed) + { + _disposed = true; + + _socket?.Dispose(); + _socket = null; + } + + return default; + } + } +} diff --git a/src/Shared/runtime/Quic/Implementations/MsQuic/Internal/MsQuicAddressHelpers.cs b/src/Shared/runtime/Quic/Implementations/MsQuic/Internal/MsQuicAddressHelpers.cs new file mode 100644 index 000000000000..2ecf0eb21065 --- /dev/null +++ b/src/Shared/runtime/Quic/Implementations/MsQuic/Internal/MsQuicAddressHelpers.cs @@ -0,0 +1,85 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Net.Sockets; +using static System.Net.Quic.Implementations.MsQuic.Internal.MsQuicNativeMethods; + +namespace System.Net.Quic.Implementations.MsQuic.Internal +{ + internal static class MsQuicAddressHelpers + { + internal const ushort IPv4 = 2; + internal const ushort IPv6 = 23; + + internal static unsafe IPEndPoint INetToIPEndPoint(SOCKADDR_INET inetAddress) + { + if (inetAddress.si_family == IPv4) + { + return new IPEndPoint(new IPAddress(inetAddress.Ipv4.Address), (ushort)IPAddress.NetworkToHostOrder((short)inetAddress.Ipv4.sin_port)); + } + else + { + return new IPEndPoint(new IPAddress(inetAddress.Ipv6.Address), (ushort)IPAddress.NetworkToHostOrder((short)inetAddress.Ipv6._port)); + } + } + + internal static SOCKADDR_INET IPEndPointToINet(IPEndPoint endpoint) + { + SOCKADDR_INET socketAddress = default; + byte[] buffer = endpoint.Address.GetAddressBytes(); + if (endpoint.Address != IPAddress.Any && endpoint.Address != IPAddress.IPv6Any) + { + switch (endpoint.Address.AddressFamily) + { + case AddressFamily.InterNetwork: + socketAddress.Ipv4.sin_addr0 = buffer[0]; + socketAddress.Ipv4.sin_addr1 = buffer[1]; + socketAddress.Ipv4.sin_addr2 = buffer[2]; + socketAddress.Ipv4.sin_addr3 = buffer[3]; + socketAddress.Ipv4.sin_family = IPv4; + break; + case AddressFamily.InterNetworkV6: + socketAddress.Ipv6._addr0 = buffer[0]; + socketAddress.Ipv6._addr1 = buffer[1]; + socketAddress.Ipv6._addr2 = buffer[2]; + socketAddress.Ipv6._addr3 = buffer[3]; + socketAddress.Ipv6._addr4 = buffer[4]; + socketAddress.Ipv6._addr5 = buffer[5]; + socketAddress.Ipv6._addr6 = buffer[6]; + socketAddress.Ipv6._addr7 = buffer[7]; + socketAddress.Ipv6._addr8 = buffer[8]; + socketAddress.Ipv6._addr9 = buffer[9]; + socketAddress.Ipv6._addr10 = buffer[10]; + socketAddress.Ipv6._addr11 = buffer[11]; + socketAddress.Ipv6._addr12 = buffer[12]; + socketAddress.Ipv6._addr13 = buffer[13]; + socketAddress.Ipv6._addr14 = buffer[14]; + socketAddress.Ipv6._addr15 = buffer[15]; + socketAddress.Ipv6._family = IPv6; + break; + default: + throw new ArgumentException("Only IPv4 or IPv6 are supported"); + } + } + + SetPort(endpoint.Address.AddressFamily, ref socketAddress, endpoint.Port); + return socketAddress; + } + + private static void SetPort(AddressFamily addressFamily, ref SOCKADDR_INET socketAddrInet, int originalPort) + { + ushort convertedPort = (ushort)IPAddress.HostToNetworkOrder((short)originalPort); + switch (addressFamily) + { + case AddressFamily.InterNetwork: + socketAddrInet.Ipv4.sin_port = convertedPort; + break; + case AddressFamily.InterNetworkV6: + default: + socketAddrInet.Ipv6._port = convertedPort; + break; + } + } + } +} diff --git a/src/Shared/runtime/Quic/Implementations/MsQuic/Internal/MsQuicApi.cs b/src/Shared/runtime/Quic/Implementations/MsQuic/Internal/MsQuicApi.cs new file mode 100644 index 000000000000..30e5cba6cbed --- /dev/null +++ b/src/Shared/runtime/Quic/Implementations/MsQuic/Internal/MsQuicApi.cs @@ -0,0 +1,361 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.IO; +using System.Net.Security; +using System.Runtime.InteropServices; +using System.Security.Cryptography.X509Certificates; +using System.Text; +using System.Threading.Tasks; + +namespace System.Net.Quic.Implementations.MsQuic.Internal +{ + internal class MsQuicApi : IDisposable + { + private bool _disposed; + + private readonly IntPtr _registrationContext; + + private unsafe MsQuicApi() + { + MsQuicNativeMethods.NativeApi* registration; + + try + { + uint status = Interop.MsQuic.MsQuicOpen(version: 1, out registration); + if (!MsQuicStatusHelper.SuccessfulStatusCode(status)) + { + throw new NotSupportedException(SR.net_quic_notsupported); + } + } + catch (DllNotFoundException) + { + throw new NotSupportedException(SR.net_quic_notsupported); + } + + MsQuicNativeMethods.NativeApi nativeRegistration = *registration; + + RegistrationOpenDelegate = + Marshal.GetDelegateForFunctionPointer( + nativeRegistration.RegistrationOpen); + RegistrationCloseDelegate = + Marshal.GetDelegateForFunctionPointer( + nativeRegistration.RegistrationClose); + + SecConfigCreateDelegate = + Marshal.GetDelegateForFunctionPointer( + nativeRegistration.SecConfigCreate); + SecConfigDeleteDelegate = + Marshal.GetDelegateForFunctionPointer( + nativeRegistration.SecConfigDelete); + SessionOpenDelegate = + Marshal.GetDelegateForFunctionPointer( + nativeRegistration.SessionOpen); + SessionCloseDelegate = + Marshal.GetDelegateForFunctionPointer( + nativeRegistration.SessionClose); + SessionShutdownDelegate = + Marshal.GetDelegateForFunctionPointer( + nativeRegistration.SessionShutdown); + + ListenerOpenDelegate = + Marshal.GetDelegateForFunctionPointer( + nativeRegistration.ListenerOpen); + ListenerCloseDelegate = + Marshal.GetDelegateForFunctionPointer( + nativeRegistration.ListenerClose); + ListenerStartDelegate = + Marshal.GetDelegateForFunctionPointer( + nativeRegistration.ListenerStart); + ListenerStopDelegate = + Marshal.GetDelegateForFunctionPointer( + nativeRegistration.ListenerStop); + + ConnectionOpenDelegate = + Marshal.GetDelegateForFunctionPointer( + nativeRegistration.ConnectionOpen); + ConnectionCloseDelegate = + Marshal.GetDelegateForFunctionPointer( + nativeRegistration.ConnectionClose); + ConnectionShutdownDelegate = + Marshal.GetDelegateForFunctionPointer( + nativeRegistration.ConnectionShutdown); + ConnectionStartDelegate = + Marshal.GetDelegateForFunctionPointer( + nativeRegistration.ConnectionStart); + + StreamOpenDelegate = + Marshal.GetDelegateForFunctionPointer( + nativeRegistration.StreamOpen); + StreamCloseDelegate = + Marshal.GetDelegateForFunctionPointer( + nativeRegistration.StreamClose); + StreamStartDelegate = + Marshal.GetDelegateForFunctionPointer( + nativeRegistration.StreamStart); + StreamShutdownDelegate = + Marshal.GetDelegateForFunctionPointer( + nativeRegistration.StreamShutdown); + StreamSendDelegate = + Marshal.GetDelegateForFunctionPointer( + nativeRegistration.StreamSend); + StreamReceiveCompleteDelegate = + Marshal.GetDelegateForFunctionPointer( + nativeRegistration.StreamReceiveComplete); + StreamReceiveSetEnabledDelegate = + Marshal.GetDelegateForFunctionPointer( + nativeRegistration.StreamReceiveSetEnabled); + SetContextDelegate = + Marshal.GetDelegateForFunctionPointer( + nativeRegistration.SetContext); + GetContextDelegate = + Marshal.GetDelegateForFunctionPointer( + nativeRegistration.GetContext); + SetCallbackHandlerDelegate = + Marshal.GetDelegateForFunctionPointer( + nativeRegistration.SetCallbackHandler); + + SetParamDelegate = + Marshal.GetDelegateForFunctionPointer( + nativeRegistration.SetParam); + GetParamDelegate = + Marshal.GetDelegateForFunctionPointer( + nativeRegistration.GetParam); + + RegistrationOpenDelegate(Encoding.UTF8.GetBytes("SystemNetQuic"), out IntPtr ctx); + _registrationContext = ctx; + } + + internal static MsQuicApi Api { get; } + + internal static bool IsQuicSupported { get; } + + static MsQuicApi() + { + // MsQuicOpen will succeed even if the platform will not support it. It will then fail with unspecified + // platform-specific errors in subsequent callbacks. For now, check for the minimum build we've tested it on. + + // TODO: + // - Hopefully, MsQuicOpen will perform this check for us and give us a consistent error code. + // - Otherwise, dial this in to reflect actual minimum requirements and add some sort of platform + // error code mapping when creating exceptions. + + OperatingSystem ver = Environment.OSVersion; + + if (ver.Platform == PlatformID.Win32NT && ver.Version < new Version(10, 0, 19041, 0)) + { + IsQuicSupported = false; + return; + } + + // TODO: try to initialize TLS 1.3 in SslStream. + + try + { + Api = new MsQuicApi(); + IsQuicSupported = true; + } + catch (NotSupportedException) + { + IsQuicSupported = false; + } + } + + internal MsQuicNativeMethods.RegistrationOpenDelegate RegistrationOpenDelegate { get; } + internal MsQuicNativeMethods.RegistrationCloseDelegate RegistrationCloseDelegate { get; } + + internal MsQuicNativeMethods.SecConfigCreateDelegate SecConfigCreateDelegate { get; } + internal MsQuicNativeMethods.SecConfigDeleteDelegate SecConfigDeleteDelegate { get; } + + internal MsQuicNativeMethods.SessionOpenDelegate SessionOpenDelegate { get; } + internal MsQuicNativeMethods.SessionCloseDelegate SessionCloseDelegate { get; } + internal MsQuicNativeMethods.SessionShutdownDelegate SessionShutdownDelegate { get; } + + internal MsQuicNativeMethods.ListenerOpenDelegate ListenerOpenDelegate { get; } + internal MsQuicNativeMethods.ListenerCloseDelegate ListenerCloseDelegate { get; } + internal MsQuicNativeMethods.ListenerStartDelegate ListenerStartDelegate { get; } + internal MsQuicNativeMethods.ListenerStopDelegate ListenerStopDelegate { get; } + + internal MsQuicNativeMethods.ConnectionOpenDelegate ConnectionOpenDelegate { get; } + internal MsQuicNativeMethods.ConnectionCloseDelegate ConnectionCloseDelegate { get; } + internal MsQuicNativeMethods.ConnectionShutdownDelegate ConnectionShutdownDelegate { get; } + internal MsQuicNativeMethods.ConnectionStartDelegate ConnectionStartDelegate { get; } + + internal MsQuicNativeMethods.StreamOpenDelegate StreamOpenDelegate { get; } + internal MsQuicNativeMethods.StreamCloseDelegate StreamCloseDelegate { get; } + internal MsQuicNativeMethods.StreamStartDelegate StreamStartDelegate { get; } + internal MsQuicNativeMethods.StreamShutdownDelegate StreamShutdownDelegate { get; } + internal MsQuicNativeMethods.StreamSendDelegate StreamSendDelegate { get; } + internal MsQuicNativeMethods.StreamReceiveCompleteDelegate StreamReceiveCompleteDelegate { get; } + internal MsQuicNativeMethods.StreamReceiveSetEnabledDelegate StreamReceiveSetEnabledDelegate { get; } + + internal MsQuicNativeMethods.SetContextDelegate SetContextDelegate { get; } + internal MsQuicNativeMethods.GetContextDelegate GetContextDelegate { get; } + internal MsQuicNativeMethods.SetCallbackHandlerDelegate SetCallbackHandlerDelegate { get; } + + internal MsQuicNativeMethods.SetParamDelegate SetParamDelegate { get; } + internal MsQuicNativeMethods.GetParamDelegate GetParamDelegate { get; } + + internal unsafe uint UnsafeSetParam( + IntPtr Handle, + uint Level, + uint Param, + MsQuicNativeMethods.QuicBuffer Buffer) + { + return SetParamDelegate( + Handle, + Level, + Param, + Buffer.Length, + Buffer.Buffer); + } + + internal unsafe uint UnsafeGetParam( + IntPtr Handle, + uint Level, + uint Param, + ref MsQuicNativeMethods.QuicBuffer Buffer) + { + uint bufferLength = Buffer.Length; + byte* buf = Buffer.Buffer; + return GetParamDelegate( + Handle, + Level, + Param, + &bufferLength, + buf); + } + + public async ValueTask CreateSecurityConfig(X509Certificate certificate, string certFilePath, string privateKeyFilePath) + { + MsQuicSecurityConfig secConfig = null; + var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + uint secConfigCreateStatus = MsQuicStatusCodes.InternalError; + uint createConfigStatus; + IntPtr unmanagedAddr = IntPtr.Zero; + MsQuicNativeMethods.CertFileParams fileParams = default; + + try + { + if (certFilePath != null && privateKeyFilePath != null) + { + fileParams = new MsQuicNativeMethods.CertFileParams + { + CertificateFilePath = Marshal.StringToCoTaskMemUTF8(certFilePath), + PrivateKeyFilePath = Marshal.StringToCoTaskMemUTF8(privateKeyFilePath) + }; + + unmanagedAddr = Marshal.AllocHGlobal(Marshal.SizeOf(fileParams)); + Marshal.StructureToPtr(fileParams, unmanagedAddr, fDeleteOld: false); + + createConfigStatus = SecConfigCreateDelegate( + _registrationContext, + (uint)QUIC_SEC_CONFIG_FLAG.CERT_FILE, + certificate.Handle, + null, + IntPtr.Zero, + SecCfgCreateCallbackHandler); + } + else if (certificate != null) + { + createConfigStatus = SecConfigCreateDelegate( + _registrationContext, + (uint)QUIC_SEC_CONFIG_FLAG.CERT_CONTEXT, + certificate.Handle, + null, + IntPtr.Zero, + SecCfgCreateCallbackHandler); + } + else + { + // If no certificate is provided, provide a null one. + createConfigStatus = SecConfigCreateDelegate( + _registrationContext, + (uint)QUIC_SEC_CONFIG_FLAG.CERT_NULL, + IntPtr.Zero, + null, + IntPtr.Zero, + SecCfgCreateCallbackHandler); + } + + QuicExceptionHelpers.ThrowIfFailed( + createConfigStatus, + "Could not create security configuration."); + + void SecCfgCreateCallbackHandler( + IntPtr context, + uint status, + IntPtr securityConfig) + { + secConfig = new MsQuicSecurityConfig(this, securityConfig); + secConfigCreateStatus = status; + tcs.SetResult(null); + } + + await tcs.Task.ConfigureAwait(false); + + QuicExceptionHelpers.ThrowIfFailed( + secConfigCreateStatus, + "Could not create security configuration."); + } + finally + { + if (fileParams.CertificateFilePath != IntPtr.Zero) + { + Marshal.FreeCoTaskMem(fileParams.CertificateFilePath); + } + + if (fileParams.PrivateKeyFilePath != IntPtr.Zero) + { + Marshal.FreeCoTaskMem(fileParams.PrivateKeyFilePath); + } + + if (unmanagedAddr != IntPtr.Zero) + { + Marshal.FreeHGlobal(unmanagedAddr); + } + } + + return secConfig; + } + + public IntPtr SessionOpen(byte[] alpn) + { + IntPtr sessionPtr = IntPtr.Zero; + + uint status = SessionOpenDelegate( + _registrationContext, + alpn, + IntPtr.Zero, + ref sessionPtr); + + QuicExceptionHelpers.ThrowIfFailed(status, "Could not open session."); + + return sessionPtr; + } + + public void Dispose() + { + Dispose(disposing: true); + GC.SuppressFinalize(this); + } + + ~MsQuicApi() + { + Dispose(disposing: false); + } + + private void Dispose(bool disposing) + { + if (_disposed) + { + return; + } + + RegistrationCloseDelegate?.Invoke(_registrationContext); + + _disposed = true; + } + } +} diff --git a/src/Shared/runtime/Quic/Implementations/MsQuic/Internal/MsQuicParameterHelpers.cs b/src/Shared/runtime/Quic/Implementations/MsQuic/Internal/MsQuicParameterHelpers.cs new file mode 100644 index 000000000000..757bb0da0545 --- /dev/null +++ b/src/Shared/runtime/Quic/Implementations/MsQuic/Internal/MsQuicParameterHelpers.cs @@ -0,0 +1,98 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using static System.Net.Quic.Implementations.MsQuic.Internal.MsQuicNativeMethods; + +namespace System.Net.Quic.Implementations.MsQuic.Internal +{ + internal static class MsQuicParameterHelpers + { + internal static unsafe SOCKADDR_INET GetINetParam(MsQuicApi api, IntPtr nativeObject, uint level, uint param) + { + byte* ptr = stackalloc byte[sizeof(SOCKADDR_INET)]; + QuicBuffer buffer = new QuicBuffer + { + Length = (uint)sizeof(SOCKADDR_INET), + Buffer = ptr + }; + + QuicExceptionHelpers.ThrowIfFailed( + api.UnsafeGetParam(nativeObject, level, param, ref buffer), + "Could not get SOCKADDR_INET."); + + return *(SOCKADDR_INET*)ptr; + } + + internal static unsafe ushort GetUShortParam(MsQuicApi api, IntPtr nativeObject, uint level, uint param) + { + byte* ptr = stackalloc byte[sizeof(ushort)]; + QuicBuffer buffer = new QuicBuffer() + { + Length = sizeof(ushort), + Buffer = ptr + }; + + QuicExceptionHelpers.ThrowIfFailed( + api.UnsafeGetParam(nativeObject, level, param, ref buffer), + "Could not get ushort."); + + return *(ushort*)ptr; + } + + internal static unsafe void SetUshortParam(MsQuicApi api, IntPtr nativeObject, uint level, uint param, ushort value) + { + QuicBuffer buffer = new QuicBuffer() + { + Length = sizeof(ushort), + Buffer = (byte*)&value + }; + + QuicExceptionHelpers.ThrowIfFailed( + api.UnsafeSetParam(nativeObject, level, param, buffer), + "Could not set ushort."); + } + + internal static unsafe ulong GetULongParam(MsQuicApi api, IntPtr nativeObject, uint level, uint param) + { + byte* ptr = stackalloc byte[sizeof(ulong)]; + QuicBuffer buffer = new QuicBuffer() + { + Length = sizeof(ulong), + Buffer = ptr + }; + + QuicExceptionHelpers.ThrowIfFailed( + api.UnsafeGetParam(nativeObject, level, param, ref buffer), + "Could not get ulong."); + + return *(ulong*)ptr; + } + + internal static unsafe void SetULongParam(MsQuicApi api, IntPtr nativeObject, uint level, uint param, ulong value) + { + QuicBuffer buffer = new QuicBuffer() + { + Length = sizeof(ulong), + Buffer = (byte*)&value + }; + + QuicExceptionHelpers.ThrowIfFailed( + api.UnsafeGetParam(nativeObject, level, param, ref buffer), + "Could not set ulong."); + } + + internal static unsafe void SetSecurityConfig(MsQuicApi api, IntPtr nativeObject, uint level, uint param, IntPtr value) + { + QuicBuffer buffer = new QuicBuffer() + { + Length = (uint)sizeof(void*), + Buffer = (byte*)&value + }; + + QuicExceptionHelpers.ThrowIfFailed( + api.UnsafeSetParam(nativeObject, level, param, buffer), + "Could not set security configuration."); + } + } +} diff --git a/src/Shared/runtime/Quic/Implementations/MsQuic/Internal/MsQuicSecurityConfig.cs b/src/Shared/runtime/Quic/Implementations/MsQuic/Internal/MsQuicSecurityConfig.cs new file mode 100644 index 000000000000..58fc811f7cf4 --- /dev/null +++ b/src/Shared/runtime/Quic/Implementations/MsQuic/Internal/MsQuicSecurityConfig.cs @@ -0,0 +1,45 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +namespace System.Net.Quic.Implementations.MsQuic.Internal +{ + // TODO this will eventually be abstracted to support both Client and Server + // certificates + internal class MsQuicSecurityConfig : IDisposable + { + private bool _disposed; + private MsQuicApi _registration; + + public MsQuicSecurityConfig(MsQuicApi registration, IntPtr nativeObjPtr) + { + _registration = registration; + NativeObjPtr = nativeObjPtr; + } + + public IntPtr NativeObjPtr { get; private set; } + + public void Dispose() + { + Dispose(disposing: true); + GC.SuppressFinalize(this); + } + + private void Dispose(bool disposing) + { + if (_disposed) + { + return; + } + + _registration.SecConfigDeleteDelegate?.Invoke(NativeObjPtr); + NativeObjPtr = IntPtr.Zero; + _disposed = true; + } + + ~MsQuicSecurityConfig() + { + Dispose(disposing: false); + } + } +} diff --git a/src/Shared/runtime/Quic/Implementations/MsQuic/Internal/MsQuicSession.cs b/src/Shared/runtime/Quic/Implementations/MsQuic/Internal/MsQuicSession.cs new file mode 100644 index 000000000000..89dd99f73c3e --- /dev/null +++ b/src/Shared/runtime/Quic/Implementations/MsQuic/Internal/MsQuicSession.cs @@ -0,0 +1,156 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +namespace System.Net.Quic.Implementations.MsQuic.Internal +{ + internal sealed class MsQuicSession : IDisposable + { + private bool _disposed = false; + private IntPtr _nativeObjPtr; + private bool _opened; + + internal MsQuicSession() + { + if (!MsQuicApi.IsQuicSupported) + { + throw new NotSupportedException(SR.net_quic_notsupported); + } + } + + public IntPtr ConnectionOpen(QuicClientConnectionOptions options) + { + if (!_opened) + { + OpenSession(options.ClientAuthenticationOptions.ApplicationProtocols[0].Protocol.ToArray(), + (ushort)options.MaxBidirectionalStreams, + (ushort)options.MaxUnidirectionalStreams); + } + + QuicExceptionHelpers.ThrowIfFailed(MsQuicApi.Api.ConnectionOpenDelegate( + _nativeObjPtr, + MsQuicConnection.NativeCallbackHandler, + IntPtr.Zero, + out IntPtr connectionPtr), + "Could not open the connection."); + + return connectionPtr; + } + + private void OpenSession(byte[] alpn, ushort bidirectionalStreamCount, ushort undirectionalStreamCount) + { + _opened = true; + _nativeObjPtr = MsQuicApi.Api.SessionOpen(alpn); + SetPeerBiDirectionalStreamCount(bidirectionalStreamCount); + SetPeerUnidirectionalStreamCount(undirectionalStreamCount); + } + + // TODO allow for a callback to select the certificate (SNI). + public IntPtr ListenerOpen(QuicListenerOptions options) + { + if (!_opened) + { + OpenSession(options.ServerAuthenticationOptions.ApplicationProtocols[0].Protocol.ToArray(), + (ushort)options.MaxBidirectionalStreams, + (ushort)options.MaxUnidirectionalStreams); + } + + QuicExceptionHelpers.ThrowIfFailed(MsQuicApi.Api.ListenerOpenDelegate( + _nativeObjPtr, + MsQuicListener.NativeCallbackHandler, + IntPtr.Zero, + out IntPtr listenerPointer), + "Could not open listener."); + + return listenerPointer; + } + + // TODO call this for graceful shutdown? + public void ShutDown( + QUIC_CONNECTION_SHUTDOWN_FLAG Flags, + ushort ErrorCode) + { + MsQuicApi.Api.SessionShutdownDelegate( + _nativeObjPtr, + (uint)Flags, + ErrorCode); + } + + public void Dispose() + { + Dispose(true); + GC.SuppressFinalize(this); + } + + public void SetPeerBiDirectionalStreamCount(ushort count) + { + SetUshortParamter(QUIC_PARAM_SESSION.PEER_BIDI_STREAM_COUNT, count); + } + + public void SetPeerUnidirectionalStreamCount(ushort count) + { + SetUshortParamter(QUIC_PARAM_SESSION.PEER_UNIDI_STREAM_COUNT, count); + } + + private unsafe void SetUshortParamter(QUIC_PARAM_SESSION param, ushort count) + { + var buffer = new MsQuicNativeMethods.QuicBuffer() + { + Length = sizeof(ushort), + Buffer = (byte*)&count + }; + + SetParam(param, buffer); + } + + public void SetDisconnectTimeout(TimeSpan timeout) + { + SetULongParamter(QUIC_PARAM_SESSION.DISCONNECT_TIMEOUT, (ulong)timeout.TotalMilliseconds); + } + + public void SetIdleTimeout(TimeSpan timeout) + { + SetULongParamter(QUIC_PARAM_SESSION.IDLE_TIMEOUT, (ulong)timeout.TotalMilliseconds); + + } + private unsafe void SetULongParamter(QUIC_PARAM_SESSION param, ulong count) + { + var buffer = new MsQuicNativeMethods.QuicBuffer() + { + Length = sizeof(ulong), + Buffer = (byte*)&count + }; + SetParam(param, buffer); + } + + private void SetParam( + QUIC_PARAM_SESSION param, + MsQuicNativeMethods.QuicBuffer buf) + { + QuicExceptionHelpers.ThrowIfFailed(MsQuicApi.Api.UnsafeSetParam( + _nativeObjPtr, + (uint)QUIC_PARAM_LEVEL.SESSION, + (uint)param, + buf), + "Could not set parameter on session."); + } + + ~MsQuicSession() + { + Dispose(false); + } + + private void Dispose(bool disposing) + { + if (_disposed) + { + return; + } + + MsQuicApi.Api.SessionCloseDelegate?.Invoke(_nativeObjPtr); + _nativeObjPtr = IntPtr.Zero; + + _disposed = true; + } + } +} diff --git a/src/Shared/runtime/Quic/Implementations/MsQuic/Internal/QuicExceptionHelpers.cs b/src/Shared/runtime/Quic/Implementations/MsQuic/Internal/QuicExceptionHelpers.cs new file mode 100644 index 000000000000..1b8ab8ef2606 --- /dev/null +++ b/src/Shared/runtime/Quic/Implementations/MsQuic/Internal/QuicExceptionHelpers.cs @@ -0,0 +1,17 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +namespace System.Net.Quic.Implementations.MsQuic.Internal +{ + internal static class QuicExceptionHelpers + { + internal static void ThrowIfFailed(uint status, string message = null, Exception innerException = null) + { + if (!MsQuicStatusHelper.SuccessfulStatusCode(status)) + { + throw new QuicException($"{message} Error Code: {MsQuicStatusCodes.GetError(status)}"); + } + } + } +} diff --git a/src/Shared/runtime/Quic/Implementations/MsQuic/Internal/ResettableCompletionSource.cs b/src/Shared/runtime/Quic/Implementations/MsQuic/Internal/ResettableCompletionSource.cs new file mode 100644 index 000000000000..1db5dc67b4c9 --- /dev/null +++ b/src/Shared/runtime/Quic/Implementations/MsQuic/Internal/ResettableCompletionSource.cs @@ -0,0 +1,81 @@ +using System.Threading.Tasks; +using System.Threading.Tasks.Sources; + +namespace System.Net.Quic.Implementations.MsQuic.Internal +{ + /// + /// A resettable completion source which can be completed multiple times. + /// Used to make methods async between completed events and their associated async method. + /// + internal class ResettableCompletionSource : IValueTaskSource, IValueTaskSource + { + protected ManualResetValueTaskSourceCore _valueTaskSource; + + public ResettableCompletionSource() + { + _valueTaskSource.RunContinuationsAsynchronously = true; + } + + public ValueTask GetValueTask() + { + return new ValueTask(this, _valueTaskSource.Version); + } + + public ValueTask GetTypelessValueTask() + { + return new ValueTask(this, _valueTaskSource.Version); + } + + public ValueTaskSourceStatus GetStatus(short token) + { + return _valueTaskSource.GetStatus(token); + } + + public void OnCompleted(Action continuation, object state, short token, ValueTaskSourceOnCompletedFlags flags) + { + _valueTaskSource.OnCompleted(continuation, state, token, flags); + } + + public void Complete(T result) + { + _valueTaskSource.SetResult(result); + } + + public void CompleteException(Exception ex) + { + _valueTaskSource.SetException(ex); + } + + public T GetResult(short token) + { + bool isValid = token == _valueTaskSource.Version; + try + { + return _valueTaskSource.GetResult(token); + } + finally + { + if (isValid) + { + _valueTaskSource.Reset(); + } + } + } + + void IValueTaskSource.GetResult(short token) + { + bool isValid = token == _valueTaskSource.Version; + try + { + _valueTaskSource.GetResult(token); + } + finally + { + if (isValid) + { + _valueTaskSource.Reset(); + } + } + } + } + } diff --git a/src/Shared/runtime/Quic/Implementations/MsQuic/MsQuicConnection.cs b/src/Shared/runtime/Quic/Implementations/MsQuic/MsQuicConnection.cs new file mode 100644 index 000000000000..1d914c2668e7 --- /dev/null +++ b/src/Shared/runtime/Quic/Implementations/MsQuic/MsQuicConnection.cs @@ -0,0 +1,416 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.IO; +using System.Net.Quic.Implementations.MsQuic.Internal; +using System.Net.Security; +using System.Runtime.InteropServices; +using System.Security.Cryptography.X509Certificates; +using System.Threading; +using System.Threading.Channels; +using System.Threading.Tasks; +using static System.Net.Quic.Implementations.MsQuic.Internal.MsQuicNativeMethods; + +namespace System.Net.Quic.Implementations.MsQuic +{ + internal sealed class MsQuicConnection : QuicConnectionProvider + { + private MsQuicSession _session; + + // Pointer to the underlying connection + // TODO replace all IntPtr with SafeHandles + private IntPtr _ptr; + + // Handle to this object for native callbacks. + private GCHandle _handle; + + // Delegate that wraps the static function that will be called when receiving an event. + // TODO investigate if the delegate can be static instead. + private ConnectionCallbackDelegate _connectionDelegate; + + // Endpoint to either connect to or the endpoint already accepted. + private IPEndPoint _localEndPoint; + private readonly IPEndPoint _remoteEndPoint; + + private readonly ResettableCompletionSource _connectTcs = new ResettableCompletionSource(); + private readonly ResettableCompletionSource _shutdownTcs = new ResettableCompletionSource(); + + private bool _disposed; + private bool _connected; + private MsQuicSecurityConfig _securityConfig; + private long _abortErrorCode = -1; + + // Queue for accepted streams + private readonly Channel _acceptQueue = Channel.CreateUnbounded(new UnboundedChannelOptions() + { + SingleReader = true, + SingleWriter = true, + }); + + // constructor for inbound connections + public MsQuicConnection(IPEndPoint localEndPoint, IPEndPoint remoteEndPoint, IntPtr nativeObjPtr) + { + if (NetEventSource.IsEnabled) NetEventSource.Enter(this); + _localEndPoint = localEndPoint; + _remoteEndPoint = remoteEndPoint; + _ptr = nativeObjPtr; + + SetCallbackHandler(); + SetIdleTimeout(TimeSpan.FromSeconds(120)); + if (NetEventSource.IsEnabled) NetEventSource.Exit(this); + } + + // constructor for outbound connections + public MsQuicConnection(QuicClientConnectionOptions options) + { + if (NetEventSource.IsEnabled) NetEventSource.Enter(this); + + // TODO need to figure out if/how we want to expose sessions + // Creating a session per connection isn't ideal. + _session = new MsQuicSession(); + _ptr = _session.ConnectionOpen(options); + _remoteEndPoint = options.RemoteEndPoint; + + SetCallbackHandler(); + SetIdleTimeout(options.IdleTimeout); + + if (NetEventSource.IsEnabled) NetEventSource.Exit(this); + } + + internal override IPEndPoint LocalEndPoint + { + get + { + return new IPEndPoint(_localEndPoint.Address, _localEndPoint.Port); + } + } + + internal async ValueTask SetSecurityConfigForConnection(X509Certificate cert, string certFilePath, string privateKeyFilePath) + { + _securityConfig = await MsQuicApi.Api.CreateSecurityConfig(cert, certFilePath, privateKeyFilePath); + // TODO this isn't being set correctly + MsQuicParameterHelpers.SetSecurityConfig(MsQuicApi.Api, _ptr, (uint)QUIC_PARAM_LEVEL.CONNECTION, (uint)QUIC_PARAM_CONN.SEC_CONFIG, _securityConfig.NativeObjPtr); + } + + internal override IPEndPoint RemoteEndPoint => new IPEndPoint(_remoteEndPoint.Address, _remoteEndPoint.Port); + + internal override SslApplicationProtocol NegotiatedApplicationProtocol => throw new NotImplementedException(); + + internal override bool Connected => _connected; + + internal uint HandleEvent(ref ConnectionEvent connectionEvent) + { + uint status = MsQuicStatusCodes.Success; + try + { + switch (connectionEvent.Type) + { + // Connection is connected, can start to create streams. + case QUIC_CONNECTION_EVENT.CONNECTED: + { + status = HandleEventConnected( + connectionEvent); + } + break; + + // Connection is being closed by the transport + case QUIC_CONNECTION_EVENT.SHUTDOWN_INITIATED_BY_TRANSPORT: + { + status = HandleEventShutdownInitiatedByTransport( + connectionEvent); + } + break; + + // Connection is being closed by the peer + case QUIC_CONNECTION_EVENT.SHUTDOWN_INITIATED_BY_PEER: + { + status = HandleEventShutdownInitiatedByPeer( + connectionEvent); + } + break; + + // Connection has been shutdown + case QUIC_CONNECTION_EVENT.SHUTDOWN_COMPLETE: + { + status = HandleEventShutdownComplete( + connectionEvent); + } + break; + + case QUIC_CONNECTION_EVENT.PEER_STREAM_STARTED: + { + status = HandleEventNewStream( + connectionEvent); + } + break; + + case QUIC_CONNECTION_EVENT.STREAMS_AVAILABLE: + { + status = HandleEventStreamsAvailable( + connectionEvent); + } + break; + + default: + break; + } + } + catch (Exception) + { + // TODO we may want to either add a debug assert here or return specific error codes + // based on the exception caught. + return MsQuicStatusCodes.InternalError; + } + + return status; + } + + private uint HandleEventConnected(ConnectionEvent connectionEvent) + { + if (NetEventSource.IsEnabled) NetEventSource.Enter(this); + + SOCKADDR_INET inetAddress = MsQuicParameterHelpers.GetINetParam(MsQuicApi.Api, _ptr, (uint)QUIC_PARAM_LEVEL.CONNECTION, (uint)QUIC_PARAM_CONN.LOCAL_ADDRESS); + _localEndPoint = MsQuicAddressHelpers.INetToIPEndPoint(inetAddress); + + _connected = true; + // I don't believe we need to lock here because + // handle event connected will not be called at the same time as + // handle event shutdown initiated by transport + _connectTcs.Complete(MsQuicStatusCodes.Success); + + if (NetEventSource.IsEnabled) NetEventSource.Exit(this); + return MsQuicStatusCodes.Success; + } + + private uint HandleEventShutdownInitiatedByTransport(ConnectionEvent connectionEvent) + { + if (NetEventSource.IsEnabled) NetEventSource.Enter(this); + + if (!_connected) + { + _connectTcs.CompleteException(new IOException("Connection has been shutdown.")); + } + + _acceptQueue.Writer.Complete(); + + if (NetEventSource.IsEnabled) NetEventSource.Exit(this); + + return MsQuicStatusCodes.Success; + } + + private uint HandleEventShutdownInitiatedByPeer(ConnectionEvent connectionEvent) + { + _abortErrorCode = connectionEvent.Data.ShutdownBeginPeer.ErrorCode; + _acceptQueue.Writer.Complete(); + return MsQuicStatusCodes.Success; + } + + private uint HandleEventShutdownComplete(ConnectionEvent connectionEvent) + { + if (NetEventSource.IsEnabled) NetEventSource.Enter(this); + + _shutdownTcs.Complete(MsQuicStatusCodes.Success); + + if (NetEventSource.IsEnabled) NetEventSource.Exit(this); + return MsQuicStatusCodes.Success; + } + + private uint HandleEventNewStream(ConnectionEvent connectionEvent) + { + if (NetEventSource.IsEnabled) NetEventSource.Enter(this); + + MsQuicStream msQuicStream = new MsQuicStream(this, connectionEvent.StreamFlags, connectionEvent.Data.NewStream.Stream, inbound: true); + + _acceptQueue.Writer.TryWrite(msQuicStream); + if (NetEventSource.IsEnabled) NetEventSource.Exit(this); + + return MsQuicStatusCodes.Success; + } + + private uint HandleEventStreamsAvailable(ConnectionEvent connectionEvent) + { + return MsQuicStatusCodes.Success; + } + + internal override async ValueTask AcceptStreamAsync(CancellationToken cancellationToken = default) + { + if (NetEventSource.IsEnabled) NetEventSource.Enter(this); + + ThrowIfDisposed(); + + MsQuicStream stream; + + try + { + stream = await _acceptQueue.Reader.ReadAsync(cancellationToken).ConfigureAwait(false); + } + catch (ChannelClosedException) + { + throw _abortErrorCode switch + { + -1 => new QuicOperationAbortedException(), // Shutdown initiated by us. + long err => new QuicConnectionAbortedException(err) // Shutdown initiated by peer. + }; + } + + if (NetEventSource.IsEnabled) NetEventSource.Exit(this); + return stream; + } + + internal override QuicStreamProvider OpenUnidirectionalStream() + { + ThrowIfDisposed(); + + return StreamOpen(QUIC_STREAM_OPEN_FLAG.UNIDIRECTIONAL); + } + + internal override QuicStreamProvider OpenBidirectionalStream() + { + ThrowIfDisposed(); + + return StreamOpen(QUIC_STREAM_OPEN_FLAG.NONE); + } + + internal override long GetRemoteAvailableUnidirectionalStreamCount() + { + return MsQuicParameterHelpers.GetUShortParam(MsQuicApi.Api, _ptr, (uint)QUIC_PARAM_LEVEL.CONNECTION, (uint)QUIC_PARAM_CONN.PEER_UNIDI_STREAM_COUNT); + } + + internal override long GetRemoteAvailableBidirectionalStreamCount() + { + return MsQuicParameterHelpers.GetUShortParam(MsQuicApi.Api, _ptr, (uint)QUIC_PARAM_LEVEL.CONNECTION, (uint)QUIC_PARAM_CONN.PEER_BIDI_STREAM_COUNT); + } + + private unsafe void SetIdleTimeout(TimeSpan timeout) + { + MsQuicParameterHelpers.SetULongParam(MsQuicApi.Api, _ptr, (uint)QUIC_PARAM_LEVEL.CONNECTION, (uint)QUIC_PARAM_CONN.IDLE_TIMEOUT, (ulong)timeout.TotalMilliseconds); + } + + internal override ValueTask ConnectAsync(CancellationToken cancellationToken = default) + { + ThrowIfDisposed(); + + QuicExceptionHelpers.ThrowIfFailed( + MsQuicApi.Api.ConnectionStartDelegate( + _ptr, + (ushort)_remoteEndPoint.AddressFamily, + _remoteEndPoint.Address.ToString(), + (ushort)_remoteEndPoint.Port), + "Failed to connect to peer."); + + return _connectTcs.GetTypelessValueTask(); + } + + private MsQuicStream StreamOpen( + QUIC_STREAM_OPEN_FLAG flags) + { + if (NetEventSource.IsEnabled) NetEventSource.Enter(this); + + IntPtr streamPtr = IntPtr.Zero; + QuicExceptionHelpers.ThrowIfFailed( + MsQuicApi.Api.StreamOpenDelegate( + _ptr, + (uint)flags, + MsQuicStream.NativeCallbackHandler, + IntPtr.Zero, + out streamPtr), + "Failed to open stream to peer."); + + MsQuicStream stream = new MsQuicStream(this, flags, streamPtr, inbound: false); + + if (NetEventSource.IsEnabled) NetEventSource.Exit(this); + return stream; + } + + private void SetCallbackHandler() + { + _handle = GCHandle.Alloc(this); + _connectionDelegate = new ConnectionCallbackDelegate(NativeCallbackHandler); + MsQuicApi.Api.SetCallbackHandlerDelegate( + _ptr, + _connectionDelegate, + GCHandle.ToIntPtr(_handle)); + } + + private ValueTask ShutdownAsync( + QUIC_CONNECTION_SHUTDOWN_FLAG Flags, + long ErrorCode) + { + if (NetEventSource.IsEnabled) NetEventSource.Enter(this); + + uint status = MsQuicApi.Api.ConnectionShutdownDelegate( + _ptr, + (uint)Flags, + ErrorCode); + QuicExceptionHelpers.ThrowIfFailed(status, "Failed to shutdown connection."); + + if (NetEventSource.IsEnabled) NetEventSource.Exit(this); + return _shutdownTcs.GetTypelessValueTask(); + } + + internal static uint NativeCallbackHandler( + IntPtr connection, + IntPtr context, + ref ConnectionEvent connectionEventStruct) + { + GCHandle handle = GCHandle.FromIntPtr(context); + MsQuicConnection quicConnection = (MsQuicConnection)handle.Target; + return quicConnection.HandleEvent(ref connectionEventStruct); + } + + public override void Dispose() + { + Dispose(true); + GC.SuppressFinalize(this); + } + + ~MsQuicConnection() + { + Dispose(false); + } + + private void Dispose(bool disposing) + { + if (_disposed) + { + return; + } + + if (NetEventSource.IsEnabled) NetEventSource.Enter(this); + + if (_ptr != IntPtr.Zero) + { + MsQuicApi.Api.ConnectionCloseDelegate?.Invoke(_ptr); + } + + _ptr = IntPtr.Zero; + + if (disposing) + { + _handle.Free(); + _session?.Dispose(); + _securityConfig?.Dispose(); + } + + _disposed = true; + + if (NetEventSource.IsEnabled) NetEventSource.Exit(this); + } + + internal override ValueTask CloseAsync(long errorCode, CancellationToken cancellationToken = default) + { + ThrowIfDisposed(); + + return ShutdownAsync(QUIC_CONNECTION_SHUTDOWN_FLAG.NONE, errorCode); + } + + private void ThrowIfDisposed() + { + if (_disposed) + { + throw new ObjectDisposedException(nameof(MsQuicStream)); + } + } + } +} diff --git a/src/Shared/runtime/Quic/Implementations/MsQuic/MsQuicImplementationProvider.cs b/src/Shared/runtime/Quic/Implementations/MsQuic/MsQuicImplementationProvider.cs new file mode 100644 index 000000000000..55c5e524ec57 --- /dev/null +++ b/src/Shared/runtime/Quic/Implementations/MsQuic/MsQuicImplementationProvider.cs @@ -0,0 +1,22 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Net.Quic.Implementations.MsQuic.Internal; +using System.Net.Security; + +namespace System.Net.Quic.Implementations.MsQuic +{ + internal sealed class MsQuicImplementationProvider : QuicImplementationProvider + { + internal override QuicListenerProvider CreateListener(QuicListenerOptions options) + { + return new MsQuicListener(options); + } + + internal override QuicConnectionProvider CreateConnection(QuicClientConnectionOptions options) + { + return new MsQuicConnection(options); + } + } +} diff --git a/src/Shared/runtime/Quic/Implementations/MsQuic/MsQuicListener.cs b/src/Shared/runtime/Quic/Implementations/MsQuic/MsQuicListener.cs new file mode 100644 index 000000000000..14323d963f3e --- /dev/null +++ b/src/Shared/runtime/Quic/Implementations/MsQuic/MsQuicListener.cs @@ -0,0 +1,213 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Net.Quic.Implementations.MsQuic.Internal; +using System.Net.Security; +using System.Runtime.InteropServices; +using System.Threading; +using System.Threading.Channels; +using System.Threading.Tasks; +using static System.Net.Quic.Implementations.MsQuic.Internal.MsQuicNativeMethods; + +namespace System.Net.Quic.Implementations.MsQuic +{ + internal sealed class MsQuicListener : QuicListenerProvider, IDisposable + { + // Security configuration for MsQuic + private MsQuicSession _session; + + // Pointer to the underlying listener + // TODO replace all IntPtr with SafeHandles + private IntPtr _ptr; + + // Handle to this object for native callbacks. + private GCHandle _handle; + + // Delegate that wraps the static function that will be called when receiving an event. + private ListenerCallbackDelegate _listenerDelegate; + + // Ssl listening options (ALPN, cert, etc) + private SslServerAuthenticationOptions _sslOptions; + + private QuicListenerOptions _options; + private volatile bool _disposed; + private IPEndPoint _listenEndPoint; + + private readonly Channel _acceptConnectionQueue; + + internal MsQuicListener(QuicListenerOptions options) + { + _session = new MsQuicSession(); + _acceptConnectionQueue = Channel.CreateBounded(new BoundedChannelOptions(options.ListenBacklog) + { + SingleReader = true, + SingleWriter = true + }); + + _options = options; + _sslOptions = options.ServerAuthenticationOptions; + _listenEndPoint = options.ListenEndPoint; + + _ptr = _session.ListenerOpen(options); + } + + internal override IPEndPoint ListenEndPoint + { + get + { + return new IPEndPoint(_listenEndPoint.Address, _listenEndPoint.Port); + } + } + + internal override async ValueTask AcceptConnectionAsync(CancellationToken cancellationToken = default) + { + if (NetEventSource.IsEnabled) NetEventSource.Enter(this); + + ThrowIfDisposed(); + + MsQuicConnection connection; + + try + { + connection = await _acceptConnectionQueue.Reader.ReadAsync(cancellationToken).ConfigureAwait(false); + } + catch (ChannelClosedException) + { + throw new QuicOperationAbortedException(); + } + + await connection.SetSecurityConfigForConnection(_sslOptions.ServerCertificate, + _options.CertificateFilePath, + _options.PrivateKeyFilePath); + + if (NetEventSource.IsEnabled) NetEventSource.Exit(this); + return connection; + } + + public override void Dispose() + { + Dispose(true); + GC.SuppressFinalize(this); + } + + ~MsQuicListener() + { + Dispose(false); + } + + private void Dispose(bool disposing) + { + if (_disposed) + { + return; + } + + StopAcceptingConnections(); + + if (_ptr != IntPtr.Zero) + { + MsQuicApi.Api.ListenerStopDelegate(_ptr); + MsQuicApi.Api.ListenerCloseDelegate(_ptr); + } + + _ptr = IntPtr.Zero; + + // TODO this call to session dispose hangs. + //_session.Dispose(); + _disposed = true; + } + + internal override void Start() + { + ThrowIfDisposed(); + + SetCallbackHandler(); + + SOCKADDR_INET address = MsQuicAddressHelpers.IPEndPointToINet(_listenEndPoint); + + QuicExceptionHelpers.ThrowIfFailed(MsQuicApi.Api.ListenerStartDelegate( + _ptr, + ref address), + "Failed to start listener."); + + SetListenPort(); + } + + internal override void Close() + { + ThrowIfDisposed(); + + MsQuicApi.Api.ListenerStopDelegate(_ptr); + } + + private unsafe void SetListenPort() + { + SOCKADDR_INET inetAddress = MsQuicParameterHelpers.GetINetParam(MsQuicApi.Api, _ptr, (uint)QUIC_PARAM_LEVEL.LISTENER, (uint)QUIC_PARAM_LISTENER.LOCAL_ADDRESS); + + _listenEndPoint = MsQuicAddressHelpers.INetToIPEndPoint(inetAddress); + } + + internal unsafe uint ListenerCallbackHandler( + ref ListenerEvent evt) + { + try + { + switch (evt.Type) + { + case QUIC_LISTENER_EVENT.NEW_CONNECTION: + { + NewConnectionInfo connectionInfo = *(NewConnectionInfo*)evt.Data.NewConnection.Info; + IPEndPoint localEndPoint = MsQuicAddressHelpers.INetToIPEndPoint(*(SOCKADDR_INET*)connectionInfo.LocalAddress); + IPEndPoint remoteEndPoint = MsQuicAddressHelpers.INetToIPEndPoint(*(SOCKADDR_INET*)connectionInfo.RemoteAddress); + MsQuicConnection msQuicConnection = new MsQuicConnection(localEndPoint, remoteEndPoint, evt.Data.NewConnection.Connection); + _acceptConnectionQueue.Writer.TryWrite(msQuicConnection); + } + // Always pend the new connection to wait for the security config to be resolved + // TODO this doesn't need to be async always + return MsQuicStatusCodes.Pending; + default: + return MsQuicStatusCodes.InternalError; + } + } + catch (Exception) + { + return MsQuicStatusCodes.InternalError; + } + } + + private void StopAcceptingConnections() + { + _acceptConnectionQueue.Writer.TryComplete(); + } + + internal static uint NativeCallbackHandler( + IntPtr listener, + IntPtr context, + ref ListenerEvent connectionEventStruct) + { + GCHandle handle = GCHandle.FromIntPtr(context); + MsQuicListener quicListener = (MsQuicListener)handle.Target; + + return quicListener.ListenerCallbackHandler(ref connectionEventStruct); + } + + internal void SetCallbackHandler() + { + _handle = GCHandle.Alloc(this); + _listenerDelegate = new ListenerCallbackDelegate(NativeCallbackHandler); + MsQuicApi.Api.SetCallbackHandlerDelegate( + _ptr, + _listenerDelegate, + GCHandle.ToIntPtr(_handle)); + } + + private void ThrowIfDisposed() + { + if (_disposed) + { + throw new ObjectDisposedException(nameof(MsQuicStream)); + } + } + } +} diff --git a/src/Shared/runtime/Quic/Implementations/MsQuic/MsQuicStream.cs b/src/Shared/runtime/Quic/Implementations/MsQuic/MsQuicStream.cs new file mode 100644 index 000000000000..00ca779e3ce7 --- /dev/null +++ b/src/Shared/runtime/Quic/Implementations/MsQuic/MsQuicStream.cs @@ -0,0 +1,1042 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Buffers; +using System.Collections.Generic; +using System.Diagnostics; +using System.Net.Quic.Implementations.MsQuic.Internal; +using System.Runtime.InteropServices; +using System.Threading; +using System.Threading.Tasks; +using static System.Net.Quic.Implementations.MsQuic.Internal.MsQuicNativeMethods; + +namespace System.Net.Quic.Implementations.MsQuic +{ + internal sealed class MsQuicStream : QuicStreamProvider + { + // Pointer to the underlying stream + // TODO replace all IntPtr with SafeHandles + private readonly IntPtr _ptr; + + // Handle to this object for native callbacks. + private GCHandle _handle; + + // Delegate that wraps the static function that will be called when receiving an event. + private StreamCallbackDelegate _callback; + + // Backing for StreamId + private long _streamId = -1; + + // Resettable completions to be used for multiple calls to send, start, and shutdown. + private readonly ResettableCompletionSource _sendResettableCompletionSource; + + // Resettable completions to be used for multiple calls to receive. + private readonly ResettableCompletionSource _receiveResettableCompletionSource; + + private readonly ResettableCompletionSource _shutdownWriteResettableCompletionSource; + + // Buffers to hold during a call to send. + private MemoryHandle[] _bufferArrays = new MemoryHandle[1]; + private QuicBuffer[] _sendQuicBuffers = new QuicBuffer[1]; + + // Handle to hold when sending. + private GCHandle _sendHandle; + + // Used to check if StartAsync has been called. + private bool _started; + + private ReadState _readState; + private long _readErrorCode = -1; + + private ShutdownWriteState _shutdownState; + + private SendState _sendState; + private long _sendErrorCode = -1; + + // Used by the class to indicate that the stream is m_Readable. + private readonly bool _canRead; + + // Used by the class to indicate that the stream is writable. + private readonly bool _canWrite; + + private volatile bool _disposed = false; + + private List _receiveQuicBuffers = new List(); + + // TODO consider using Interlocked.Exchange instead of a sync if we can avoid it. + private object _sync = new object(); + + // Creates a new MsQuicStream + internal MsQuicStream(MsQuicConnection connection, QUIC_STREAM_OPEN_FLAG flags, IntPtr nativeObjPtr, bool inbound) + { + Debug.Assert(connection != null); + + _ptr = nativeObjPtr; + + _sendResettableCompletionSource = new ResettableCompletionSource(); + _receiveResettableCompletionSource = new ResettableCompletionSource(); + _shutdownWriteResettableCompletionSource = new ResettableCompletionSource(); + SetCallbackHandler(); + + if (inbound) + { + _started = true; + _canWrite = !flags.HasFlag(QUIC_STREAM_OPEN_FLAG.UNIDIRECTIONAL); + _canRead = true; + } + else + { + _canWrite = true; + _canRead = !flags.HasFlag(QUIC_STREAM_OPEN_FLAG.UNIDIRECTIONAL); + StartWrites(); + } + } + + internal override bool CanRead => _canRead; + + internal override bool CanWrite => _canWrite; + + internal override long StreamId + { + get + { + ThrowIfDisposed(); + + if (_streamId == -1) + { + _streamId = GetStreamId(); + } + + return _streamId; + } + } + + internal override ValueTask WriteAsync(ReadOnlyMemory buffer, CancellationToken cancellationToken = default) + { + return WriteAsync(buffer, endStream: false, cancellationToken); + } + + internal override ValueTask WriteAsync(ReadOnlySequence buffers, CancellationToken cancellationToken = default) + { + return WriteAsync(buffers, endStream: false, cancellationToken); + } + + internal override async ValueTask WriteAsync(ReadOnlySequence buffers, bool endStream, CancellationToken cancellationToken = default) + { + if (NetEventSource.IsEnabled) NetEventSource.Enter(this); + + ThrowIfDisposed(); + + using CancellationTokenRegistration registration = await HandleWriteStartState(cancellationToken); + + await SendReadOnlySequenceAsync(buffers, endStream ? QUIC_SEND_FLAG.FIN : QUIC_SEND_FLAG.NONE); + + HandleWriteCompletedState(); + + if (NetEventSource.IsEnabled) NetEventSource.Exit(this); + } + + internal override ValueTask WriteAsync(ReadOnlyMemory> buffers, CancellationToken cancellationToken = default) + { + return WriteAsync(buffers, endStream: false, cancellationToken); + } + + internal override async ValueTask WriteAsync(ReadOnlyMemory> buffers, bool endStream, CancellationToken cancellationToken = default) + { + if (NetEventSource.IsEnabled) NetEventSource.Enter(this); + + ThrowIfDisposed(); + + using CancellationTokenRegistration registration = await HandleWriteStartState(cancellationToken); + + await SendReadOnlyMemoryListAsync(buffers, endStream ? QUIC_SEND_FLAG.FIN : QUIC_SEND_FLAG.NONE); + + HandleWriteCompletedState(); + + if (NetEventSource.IsEnabled) NetEventSource.Exit(this); + } + + internal override async ValueTask WriteAsync(ReadOnlyMemory buffer, bool endStream, CancellationToken cancellationToken = default) + { + if (NetEventSource.IsEnabled) NetEventSource.Enter(this); + + ThrowIfDisposed(); + + using CancellationTokenRegistration registration = await HandleWriteStartState(cancellationToken); + + await SendReadOnlyMemoryAsync(buffer, endStream ? QUIC_SEND_FLAG.FIN : QUIC_SEND_FLAG.NONE); + + HandleWriteCompletedState(); + + if (NetEventSource.IsEnabled) NetEventSource.Exit(this); + } + + private async ValueTask HandleWriteStartState(CancellationToken cancellationToken) + { + if (!_canWrite) + { + throw new InvalidOperationException("Writing is not allowed on stream."); + } + + lock (_sync) + { + if (_sendState == SendState.Aborted) + { + throw new OperationCanceledException("Sending has already been aborted on the stream"); + } + } + + CancellationTokenRegistration registration = cancellationToken.Register(() => + { + bool shouldComplete = false; + lock (_sync) + { + if (_sendState == SendState.None) + { + _sendState = SendState.Aborted; + shouldComplete = true; + } + } + + if (shouldComplete) + { + _sendResettableCompletionSource.CompleteException(new OperationCanceledException("Write was canceled", cancellationToken)); + } + }); + + // Make sure start has completed + if (!_started) + { + await _sendResettableCompletionSource.GetTypelessValueTask(); + _started = true; + } + + return registration; + } + + private void HandleWriteCompletedState() + { + lock (_sync) + { + if (_sendState == SendState.Finished) + { + _sendState = SendState.None; + } + } + } + + internal override async ValueTask ReadAsync(Memory destination, CancellationToken cancellationToken = default) + { + if (NetEventSource.IsEnabled) NetEventSource.Enter(this); + + ThrowIfDisposed(); + + if (!_canRead) + { + throw new InvalidOperationException("Reading is not allowed on stream."); + } + + lock (_sync) + { + if (_readState == ReadState.ReadsCompleted) + { + if (NetEventSource.IsEnabled) NetEventSource.Exit(this); + return 0; + } + else if (_readState == ReadState.Aborted) + { + throw _readErrorCode switch + { + -1 => new QuicOperationAbortedException(), + long err => new QuicStreamAbortedException(err) + }; + } + } + + using CancellationTokenRegistration registration = cancellationToken.Register(() => + { + bool shouldComplete = false; + lock (_sync) + { + if (_readState == ReadState.None) + { + shouldComplete = true; + } + + _readState = ReadState.Aborted; + } + + if (shouldComplete) + { + _receiveResettableCompletionSource.CompleteException(new OperationCanceledException("Read was canceled", cancellationToken)); + } + }); + + // TODO there could potentially be a perf gain by storing the buffer from the inital read + // This reduces the amount of async calls, however it makes it so MsQuic holds onto the buffers + // longer than it needs to. We will need to benchmark this. + int length = (int)await _receiveResettableCompletionSource.GetValueTask(); + + int actual = Math.Min(length, destination.Length); + + static unsafe void CopyToBuffer(Span destinationBuffer, List sourceBuffers) + { + Span slicedBuffer = destinationBuffer; + for (int i = 0; i < sourceBuffers.Count; i++) + { + QuicBuffer nativeBuffer = sourceBuffers[i]; + int length = Math.Min((int)nativeBuffer.Length, slicedBuffer.Length); + new Span(nativeBuffer.Buffer, length).CopyTo(slicedBuffer); + if (length < nativeBuffer.Length) + { + // The buffer passed in was larger that the received data, return + return; + } + slicedBuffer = slicedBuffer.Slice(length); + } + } + + CopyToBuffer(destination.Span, _receiveQuicBuffers); + + lock (_sync) + { + if (_readState == ReadState.IndividualReadComplete) + { + _receiveQuicBuffers.Clear(); + ReceiveComplete(actual); + EnableReceive(); + _readState = ReadState.None; + } + } + + if (NetEventSource.IsEnabled) NetEventSource.Exit(this); + + return actual; + } + + // TODO do we want this to be a synchronization mechanism to cancel a pending read + // If so, we need to complete the read here as well. + internal override void AbortRead(long errorCode) + { + if (NetEventSource.IsEnabled) NetEventSource.Enter(this); + + ThrowIfDisposed(); + + lock (_sync) + { + _readState = ReadState.Aborted; + } + + MsQuicApi.Api.StreamShutdownDelegate(_ptr, (uint)QUIC_STREAM_SHUTDOWN_FLAG.ABORT_RECV, errorCode); + + if (NetEventSource.IsEnabled) NetEventSource.Exit(this); + } + + internal override void AbortWrite(long errorCode) + { + if (NetEventSource.IsEnabled) NetEventSource.Enter(this); + + ThrowIfDisposed(); + + bool shouldComplete = false; + + lock (_sync) + { + if (_shutdownState == ShutdownWriteState.None) + { + _shutdownState = ShutdownWriteState.Canceled; + shouldComplete = true; + } + } + + if (shouldComplete) + { + _shutdownWriteResettableCompletionSource.CompleteException(new QuicStreamAbortedException("Shutdown was aborted.", errorCode)); + } + + MsQuicApi.Api.StreamShutdownDelegate(_ptr, (uint)QUIC_STREAM_SHUTDOWN_FLAG.ABORT_SEND, errorCode); + + if (NetEventSource.IsEnabled) NetEventSource.Exit(this); + } + + internal override ValueTask ShutdownWriteCompleted(CancellationToken cancellationToken = default) + { + if (NetEventSource.IsEnabled) NetEventSource.Enter(this); + + ThrowIfDisposed(); + + // TODO do anything to stop writes? + using CancellationTokenRegistration registration = cancellationToken.Register(() => + { + bool shouldComplete = false; + lock (_sync) + { + if (_shutdownState == ShutdownWriteState.None) + { + _shutdownState = ShutdownWriteState.Canceled; + shouldComplete = true; + } + } + + if (shouldComplete) + { + _shutdownWriteResettableCompletionSource.CompleteException(new OperationCanceledException("Shutdown was canceled", cancellationToken)); + } + }); + + if (NetEventSource.IsEnabled) NetEventSource.Exit(this); + + return _shutdownWriteResettableCompletionSource.GetTypelessValueTask(); + } + + internal override void Shutdown() + { + ThrowIfDisposed(); + + MsQuicApi.Api.StreamShutdownDelegate(_ptr, (uint)QUIC_STREAM_SHUTDOWN_FLAG.GRACEFUL, errorCode: 0); + } + + // TODO consider removing sync-over-async with blocking calls. + internal override int Read(Span buffer) + { + ThrowIfDisposed(); + + return ReadAsync(buffer.ToArray()).AsTask().GetAwaiter().GetResult(); + } + + internal override void Write(ReadOnlySpan buffer) + { + ThrowIfDisposed(); + + // TODO: optimize this. + WriteAsync(buffer.ToArray()).AsTask().GetAwaiter().GetResult(); + } + + // MsQuic doesn't support explicit flushing + internal override void Flush() + { + ThrowIfDisposed(); + } + + // MsQuic doesn't support explicit flushing + internal override Task FlushAsync(CancellationToken cancellationToken = default) + { + ThrowIfDisposed(); + + return default; + } + + public override ValueTask DisposeAsync() + { + if (_disposed) + { + return default; + } + + if (NetEventSource.IsEnabled) NetEventSource.Enter(this); + + CleanupSendState(); + + if (_ptr != IntPtr.Zero) + { + // TODO resolve graceful vs abortive dispose here. Will file a separate issue. + //MsQuicApi.Api._streamShutdownDelegate(_ptr, (uint)QUIC_STREAM_SHUTDOWN_FLAG.ABORT, 1); + MsQuicApi.Api.StreamCloseDelegate?.Invoke(_ptr); + } + + _handle.Free(); + + _disposed = true; + if (NetEventSource.IsEnabled) NetEventSource.Exit(this); + + return default; + } + + public override void Dispose() + { + Dispose(true); + GC.SuppressFinalize(this); + } + + ~MsQuicStream() + { + Dispose(false); + } + + private void Dispose(bool disposing) + { + if (_disposed) + { + return; + } + + if (NetEventSource.IsEnabled) NetEventSource.Enter(this); + + CleanupSendState(); + + if (_ptr != IntPtr.Zero) + { + // TODO resolve graceful vs abortive dispose here. Will file a separate issue. + //MsQuicApi.Api._streamShutdownDelegate(_ptr, (uint)QUIC_STREAM_SHUTDOWN_FLAG.ABORT, 1); + MsQuicApi.Api.StreamCloseDelegate?.Invoke(_ptr); + } + + _handle.Free(); + + if (NetEventSource.IsEnabled) NetEventSource.Exit(this); + + _disposed = true; + } + + private void EnableReceive() + { + MsQuicApi.Api.StreamReceiveSetEnabledDelegate(_ptr, enabled: true); + } + + internal static uint NativeCallbackHandler( + IntPtr stream, + IntPtr context, + ref StreamEvent streamEvent) + { + var handle = GCHandle.FromIntPtr(context); + var quicStream = (MsQuicStream)handle.Target; + + return quicStream.HandleEvent(ref streamEvent); + } + + private uint HandleEvent(ref StreamEvent evt) + { + uint status = MsQuicStatusCodes.Success; + + try + { + switch (evt.Type) + { + // Stream has started. + // Will only be done for outbound streams (inbound streams have already started) + case QUIC_STREAM_EVENT.START_COMPLETE: + status = HandleStartComplete(); + break; + // Received data on the stream + case QUIC_STREAM_EVENT.RECEIVE: + { + status = HandleEventRecv(ref evt); + } + break; + // Send has completed. + // Contains a canceled bool to indicate if the send was canceled. + case QUIC_STREAM_EVENT.SEND_COMPLETE: + { + status = HandleEventSendComplete(ref evt); + } + break; + // Peer has told us to shutdown the reading side of the stream. + case QUIC_STREAM_EVENT.PEER_SEND_SHUTDOWN: + { + status = HandleEventPeerSendShutdown(); + } + break; + // Peer has told us to abort the reading side of the stream. + case QUIC_STREAM_EVENT.PEER_SEND_ABORTED: + { + status = HandleEventPeerSendAborted(ref evt); + } + break; + // Peer has stopped receiving data, don't send anymore. + case QUIC_STREAM_EVENT.PEER_RECEIVE_ABORTED: + { + status = HandleEventPeerRecvAborted(ref evt); + } + break; + // Occurs when shutdown is completed for the send side. + // This only happens for shutdown on sending, not receiving + // Receive shutdown can only be abortive. + case QUIC_STREAM_EVENT.SEND_SHUTDOWN_COMPLETE: + { + status = HandleEventSendShutdownComplete(ref evt); + } + break; + // Shutdown for both sending and receiving is completed. + case QUIC_STREAM_EVENT.SHUTDOWN_COMPLETE: + { + status = HandleEventShutdownComplete(); + } + break; + default: + break; + } + } + catch (Exception) + { + return MsQuicStatusCodes.InternalError; + } + + return status; + } + + private unsafe uint HandleEventRecv(ref MsQuicNativeMethods.StreamEvent evt) + { + if (NetEventSource.IsEnabled) NetEventSource.Enter(this); + + StreamEventDataRecv receieveEvent = evt.Data.Recv; + for (int i = 0; i < receieveEvent.BufferCount; i++) + { + _receiveQuicBuffers.Add(receieveEvent.Buffers[i]); + } + + bool shouldComplete = false; + lock (_sync) + { + if (_readState == ReadState.None) + { + shouldComplete = true; + } + _readState = ReadState.IndividualReadComplete; + } + + if (shouldComplete) + { + _receiveResettableCompletionSource.Complete((uint)receieveEvent.TotalBufferLength); + } + + if (NetEventSource.IsEnabled) NetEventSource.Exit(this); + + return MsQuicStatusCodes.Pending; + } + + private uint HandleEventPeerRecvAborted(ref StreamEvent evt) + { + if (NetEventSource.IsEnabled) NetEventSource.Enter(this); + + bool shouldComplete = false; + lock (_sync) + { + if (_sendState == SendState.None) + { + shouldComplete = true; + } + _sendState = SendState.Aborted; + _sendErrorCode = evt.Data.PeerSendAbort.ErrorCode; + } + + if (shouldComplete) + { + _sendResettableCompletionSource.CompleteException(new QuicStreamAbortedException(_sendErrorCode)); + } + + if (NetEventSource.IsEnabled) NetEventSource.Exit(this); + + return MsQuicStatusCodes.Success; + } + + private uint HandleStartComplete() + { + if (NetEventSource.IsEnabled) NetEventSource.Enter(this); + + bool shouldComplete = false; + lock (_sync) + { + // Check send state before completing as send cancellation is shared between start and send. + if (_sendState == SendState.None) + { + shouldComplete = true; + } + } + + if (shouldComplete) + { + _sendResettableCompletionSource.Complete(MsQuicStatusCodes.Success); + } + + if (NetEventSource.IsEnabled) NetEventSource.Exit(this); + + return MsQuicStatusCodes.Success; + } + + private uint HandleEventSendShutdownComplete(ref MsQuicNativeMethods.StreamEvent evt) + { + if (NetEventSource.IsEnabled) NetEventSource.Enter(this); + bool shouldComplete = false; + lock (_sync) + { + if (_shutdownState == ShutdownWriteState.None) + { + _shutdownState = ShutdownWriteState.Finished; + shouldComplete = true; + } + } + + if (shouldComplete) + { + _shutdownWriteResettableCompletionSource.Complete(MsQuicStatusCodes.Success); + } + + if (NetEventSource.IsEnabled) NetEventSource.Exit(this); + + return MsQuicStatusCodes.Success; + } + + private uint HandleEventShutdownComplete() + { + if (NetEventSource.IsEnabled) NetEventSource.Enter(this); + + bool shouldReadComplete = false; + bool shouldShutdownWriteComplete = false; + + lock (_sync) + { + // This event won't occur within the middle of a receive. + if (NetEventSource.IsEnabled) NetEventSource.Info("Completing resettable event source."); + + if (_readState == ReadState.None) + { + shouldReadComplete = true; + } + + _readState = ReadState.ReadsCompleted; + + if (_shutdownState == ShutdownWriteState.None) + { + _shutdownState = ShutdownWriteState.Finished; + shouldShutdownWriteComplete = true; + } + } + + if (shouldReadComplete) + { + _receiveResettableCompletionSource.Complete(0); + } + + if (shouldShutdownWriteComplete) + { + _shutdownWriteResettableCompletionSource.Complete(MsQuicStatusCodes.Success); + } + + if (NetEventSource.IsEnabled) NetEventSource.Exit(this); + + return MsQuicStatusCodes.Success; + } + + private uint HandleEventPeerSendAborted(ref StreamEvent evt) + { + if (NetEventSource.IsEnabled) NetEventSource.Enter(this); + + bool shouldComplete = false; + lock (_sync) + { + if (_readState == ReadState.None) + { + shouldComplete = true; + } + _readState = ReadState.Aborted; + _readErrorCode = evt.Data.PeerSendAbort.ErrorCode; + } + + if (shouldComplete) + { + _receiveResettableCompletionSource.CompleteException(new QuicStreamAbortedException(_readErrorCode)); + } + + if (NetEventSource.IsEnabled) NetEventSource.Exit(this); + + return MsQuicStatusCodes.Success; + } + + private uint HandleEventPeerSendShutdown() + { + if (NetEventSource.IsEnabled) NetEventSource.Enter(this); + + bool shouldComplete = false; + + lock (_sync) + { + // This event won't occur within the middle of a receive. + if (NetEventSource.IsEnabled) NetEventSource.Info("Completing resettable event source."); + + if (_readState == ReadState.None) + { + shouldComplete = true; + } + + _readState = ReadState.ReadsCompleted; + } + + if (shouldComplete) + { + _receiveResettableCompletionSource.Complete(0); + } + + if (NetEventSource.IsEnabled) NetEventSource.Exit(this); + + return MsQuicStatusCodes.Success; + } + + private uint HandleEventSendComplete(ref StreamEvent evt) + { + if (NetEventSource.IsEnabled) NetEventSource.Enter(this); + + CleanupSendState(); + + // TODO throw if a write was canceled. + uint errorCode = evt.Data.SendComplete.Canceled; + + bool shouldComplete = false; + lock (_sync) + { + if (_sendState == SendState.None) + { + _sendState = SendState.Finished; + shouldComplete = true; + } + } + + if (shouldComplete) + { + _sendResettableCompletionSource.Complete(MsQuicStatusCodes.Success); + } + + if (NetEventSource.IsEnabled) NetEventSource.Exit(this); + + return MsQuicStatusCodes.Success; + } + + private void CleanupSendState() + { + if (_sendHandle.IsAllocated) + { + _sendHandle.Free(); + } + + // Callings dispose twice on a memory handle should be okay + foreach (MemoryHandle buffer in _bufferArrays) + { + buffer.Dispose(); + } + } + + private void SetCallbackHandler() + { + _handle = GCHandle.Alloc(this); + + _callback = new StreamCallbackDelegate(NativeCallbackHandler); + MsQuicApi.Api.SetCallbackHandlerDelegate( + _ptr, + _callback, + GCHandle.ToIntPtr(_handle)); + } + + // TODO prevent overlapping sends or consider supporting it. + private unsafe ValueTask SendReadOnlyMemoryAsync( + ReadOnlyMemory buffer, + QUIC_SEND_FLAG flags) + { + if (buffer.IsEmpty) + { + if ((flags & QUIC_SEND_FLAG.FIN) == QUIC_SEND_FLAG.FIN) + { + // Start graceful shutdown sequence if passed in the fin flag and there is an empty buffer. + MsQuicApi.Api.StreamShutdownDelegate(_ptr, (uint)QUIC_STREAM_SHUTDOWN_FLAG.GRACEFUL, errorCode: 0); + } + return default; + } + + MemoryHandle handle = buffer.Pin(); + _sendQuicBuffers[0].Length = (uint)buffer.Length; + _sendQuicBuffers[0].Buffer = (byte*)handle.Pointer; + + _bufferArrays[0] = handle; + + _sendHandle = GCHandle.Alloc(_sendQuicBuffers, GCHandleType.Pinned); + + var quicBufferPointer = (QuicBuffer*)Marshal.UnsafeAddrOfPinnedArrayElement(_sendQuicBuffers, 0); + + uint status = MsQuicApi.Api.StreamSendDelegate( + _ptr, + quicBufferPointer, + bufferCount: 1, + (uint)flags, + _ptr); + + if (!MsQuicStatusHelper.SuccessfulStatusCode(status)) + { + CleanupSendState(); + + // TODO this may need to be an aborted exception. + QuicExceptionHelpers.ThrowIfFailed(status, + "Could not send data to peer."); + } + + return _sendResettableCompletionSource.GetTypelessValueTask(); + } + + private unsafe ValueTask SendReadOnlySequenceAsync( + ReadOnlySequence buffers, + QUIC_SEND_FLAG flags) + { + if (buffers.IsEmpty) + { + if ((flags & QUIC_SEND_FLAG.FIN) == QUIC_SEND_FLAG.FIN) + { + // Start graceful shutdown sequence if passed in the fin flag and there is an empty buffer. + MsQuicApi.Api.StreamShutdownDelegate(_ptr, (uint)QUIC_STREAM_SHUTDOWN_FLAG.GRACEFUL, errorCode: 0); + } + return default; + } + + uint count = 0; + + foreach (ReadOnlyMemory buffer in buffers) + { + ++count; + } + + if (_sendQuicBuffers.Length < count) + { + _sendQuicBuffers = new QuicBuffer[count]; + _bufferArrays = new MemoryHandle[count]; + } + + count = 0; + + foreach (ReadOnlyMemory buffer in buffers) + { + MemoryHandle handle = buffer.Pin(); + _sendQuicBuffers[count].Length = (uint)buffer.Length; + _sendQuicBuffers[count].Buffer = (byte*)handle.Pointer; + _bufferArrays[count] = handle; + ++count; + } + + _sendHandle = GCHandle.Alloc(_sendQuicBuffers, GCHandleType.Pinned); + + var quicBufferPointer = (QuicBuffer*)Marshal.UnsafeAddrOfPinnedArrayElement(_sendQuicBuffers, 0); + + uint status = MsQuicApi.Api.StreamSendDelegate( + _ptr, + quicBufferPointer, + count, + (uint)flags, + _ptr); + + if (!MsQuicStatusHelper.SuccessfulStatusCode(status)) + { + CleanupSendState(); + + // TODO this may need to be an aborted exception. + QuicExceptionHelpers.ThrowIfFailed(status, + "Could not send data to peer."); + } + + return _sendResettableCompletionSource.GetTypelessValueTask(); + } + + private unsafe ValueTask SendReadOnlyMemoryListAsync( + ReadOnlyMemory> buffers, + QUIC_SEND_FLAG flags) + { + if (buffers.IsEmpty) + { + if ((flags & QUIC_SEND_FLAG.FIN) == QUIC_SEND_FLAG.FIN) + { + // Start graceful shutdown sequence if passed in the fin flag and there is an empty buffer. + MsQuicApi.Api.StreamShutdownDelegate(_ptr, (uint)QUIC_STREAM_SHUTDOWN_FLAG.GRACEFUL, errorCode: 0); + } + return default; + } + + ReadOnlyMemory[] array = buffers.ToArray(); + + uint length = (uint)array.Length; + + if (_sendQuicBuffers.Length < length) + { + _sendQuicBuffers = new QuicBuffer[length]; + _bufferArrays = new MemoryHandle[length]; + } + + for (int i = 0; i < length; i++) + { + ReadOnlyMemory buffer = array[i]; + MemoryHandle handle = buffer.Pin(); + _sendQuicBuffers[i].Length = (uint)buffer.Length; + _sendQuicBuffers[i].Buffer = (byte*)handle.Pointer; + _bufferArrays[i] = handle; + } + + _sendHandle = GCHandle.Alloc(_sendQuicBuffers, GCHandleType.Pinned); + + var quicBufferPointer = (QuicBuffer*)Marshal.UnsafeAddrOfPinnedArrayElement(_sendQuicBuffers, 0); + + uint status = MsQuicApi.Api.StreamSendDelegate( + _ptr, + quicBufferPointer, + length, + (uint)flags, + _ptr); + + if (!MsQuicStatusHelper.SuccessfulStatusCode(status)) + { + CleanupSendState(); + + // TODO this may need to be an aborted exception. + QuicExceptionHelpers.ThrowIfFailed(status, + "Could not send data to peer."); + } + + return _sendResettableCompletionSource.GetTypelessValueTask(); + } + + private void StartWrites() + { + Debug.Assert(!_started); + uint status = MsQuicApi.Api.StreamStartDelegate( + _ptr, + (uint)QUIC_STREAM_START_FLAG.ASYNC); + + QuicExceptionHelpers.ThrowIfFailed(status, "Could not start stream."); + } + + private void ReceiveComplete(int bufferLength) + { + uint status = MsQuicApi.Api.StreamReceiveCompleteDelegate(_ptr, (ulong)bufferLength); + QuicExceptionHelpers.ThrowIfFailed(status, "Could not complete receive call."); + } + + // This can fail if the stream isn't started. + private unsafe long GetStreamId() + { + return (long)MsQuicParameterHelpers.GetULongParam(MsQuicApi.Api, _ptr, (uint)QUIC_PARAM_LEVEL.STREAM, (uint)QUIC_PARAM_STREAM.ID); + } + + private void ThrowIfDisposed() + { + if (_disposed) + { + throw new ObjectDisposedException(nameof(MsQuicStream)); + } + } + + private enum ReadState + { + None, + IndividualReadComplete, + ReadsCompleted, + Aborted + } + + private enum ShutdownWriteState + { + None, + Canceled, + Finished + } + + private enum SendState + { + None, + Aborted, + Finished + } + } +} diff --git a/src/Shared/runtime/Quic/Implementations/QuicConnectionProvider.cs b/src/Shared/runtime/Quic/Implementations/QuicConnectionProvider.cs new file mode 100644 index 000000000000..d77bf1df76fb --- /dev/null +++ b/src/Shared/runtime/Quic/Implementations/QuicConnectionProvider.cs @@ -0,0 +1,36 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Threading; +using System.Threading.Tasks; + +namespace System.Net.Quic.Implementations +{ + internal abstract class QuicConnectionProvider : IDisposable + { + internal abstract bool Connected { get; } + + internal abstract IPEndPoint LocalEndPoint { get; } + + internal abstract IPEndPoint RemoteEndPoint { get; } + + internal abstract ValueTask ConnectAsync(CancellationToken cancellationToken = default); + + internal abstract QuicStreamProvider OpenUnidirectionalStream(); + + internal abstract QuicStreamProvider OpenBidirectionalStream(); + + internal abstract long GetRemoteAvailableUnidirectionalStreamCount(); + + internal abstract long GetRemoteAvailableBidirectionalStreamCount(); + + internal abstract ValueTask AcceptStreamAsync(CancellationToken cancellationToken = default); + + internal abstract System.Net.Security.SslApplicationProtocol NegotiatedApplicationProtocol { get; } + + internal abstract ValueTask CloseAsync(long errorCode, CancellationToken cancellationToken = default); + + public abstract void Dispose(); + } +} diff --git a/src/Shared/runtime/Quic/Implementations/QuicImplementationProvider.cs b/src/Shared/runtime/Quic/Implementations/QuicImplementationProvider.cs new file mode 100644 index 000000000000..906f4562667c --- /dev/null +++ b/src/Shared/runtime/Quic/Implementations/QuicImplementationProvider.cs @@ -0,0 +1,17 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Net.Security; + +namespace System.Net.Quic.Implementations +{ + internal abstract class QuicImplementationProvider + { + internal QuicImplementationProvider() { } + + internal abstract QuicListenerProvider CreateListener(QuicListenerOptions options); + + internal abstract QuicConnectionProvider CreateConnection(QuicClientConnectionOptions options); + } +} diff --git a/src/Shared/runtime/Quic/Implementations/QuicListenerProvider.cs b/src/Shared/runtime/Quic/Implementations/QuicListenerProvider.cs new file mode 100644 index 000000000000..f533a8bb3818 --- /dev/null +++ b/src/Shared/runtime/Quic/Implementations/QuicListenerProvider.cs @@ -0,0 +1,22 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Threading; +using System.Threading.Tasks; + +namespace System.Net.Quic.Implementations +{ + internal abstract class QuicListenerProvider : IDisposable + { + internal abstract IPEndPoint ListenEndPoint { get; } + + internal abstract ValueTask AcceptConnectionAsync(CancellationToken cancellationToken = default); + + internal abstract void Start(); + + internal abstract void Close(); + + public abstract void Dispose(); + } +} diff --git a/src/Shared/runtime/Quic/Implementations/QuicStreamProvider.cs b/src/Shared/runtime/Quic/Implementations/QuicStreamProvider.cs new file mode 100644 index 000000000000..1e96e1597acf --- /dev/null +++ b/src/Shared/runtime/Quic/Implementations/QuicStreamProvider.cs @@ -0,0 +1,53 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Buffers; +using System.Threading; +using System.Threading.Tasks; + +namespace System.Net.Quic.Implementations +{ + internal abstract class QuicStreamProvider : IDisposable, IAsyncDisposable + { + internal abstract long StreamId { get; } + + internal abstract bool CanRead { get; } + + internal abstract int Read(Span buffer); + + internal abstract ValueTask ReadAsync(Memory buffer, CancellationToken cancellationToken = default); + + internal abstract void AbortRead(long errorCode); + + internal abstract void AbortWrite(long errorCode); + + internal abstract bool CanWrite { get; } + + internal abstract void Write(ReadOnlySpan buffer); + + internal abstract ValueTask WriteAsync(ReadOnlyMemory buffer, CancellationToken cancellationToken = default); + + internal abstract ValueTask WriteAsync(ReadOnlyMemory buffer, bool endStream, CancellationToken cancellationToken = default); + + internal abstract ValueTask WriteAsync(ReadOnlySequence buffers, CancellationToken cancellationToken = default); + + internal abstract ValueTask WriteAsync(ReadOnlySequence buffers, bool endStream, CancellationToken cancellationToken = default); + + internal abstract ValueTask WriteAsync(ReadOnlyMemory> buffers, CancellationToken cancellationToken = default); + + internal abstract ValueTask WriteAsync(ReadOnlyMemory> buffers, bool endStream, CancellationToken cancellationToken = default); + + internal abstract ValueTask ShutdownWriteCompleted(CancellationToken cancellationToken = default); + + internal abstract void Shutdown(); + + internal abstract void Flush(); + + internal abstract Task FlushAsync(CancellationToken cancellationToken); + + public abstract void Dispose(); + + public abstract ValueTask DisposeAsync(); + } +} diff --git a/src/Shared/runtime/Quic/Interop/Interop.MsQuic.cs b/src/Shared/runtime/Quic/Interop/Interop.MsQuic.cs new file mode 100644 index 000000000000..25a5e5775511 --- /dev/null +++ b/src/Shared/runtime/Quic/Interop/Interop.MsQuic.cs @@ -0,0 +1,16 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Net.Quic.Implementations.MsQuic.Internal; +using System.Runtime.InteropServices; + +internal static partial class Interop +{ + internal static class MsQuic + { + [DllImport(Libraries.MsQuic)] + internal static unsafe extern uint MsQuicOpen(int version, out MsQuicNativeMethods.NativeApi* registration); + } +} diff --git a/src/Shared/runtime/Quic/Interop/MsQuicEnums.cs b/src/Shared/runtime/Quic/Interop/MsQuicEnums.cs new file mode 100644 index 000000000000..3d294bce9c80 --- /dev/null +++ b/src/Shared/runtime/Quic/Interop/MsQuicEnums.cs @@ -0,0 +1,167 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +namespace System.Net.Quic.Implementations.MsQuic.Internal +{ + /// + /// Flags to pass when creating a security config. + /// + [Flags] + internal enum QUIC_SEC_CONFIG_FLAG : uint + { + NONE = 0, + CERT_HASH = 0x00000001, + CERT_HASH_STORE = 0x00000002, + CERT_CONTEXT = 0x00000004, + CERT_FILE = 0x00000008, + ENABL_OCSP = 0x00000010, + CERT_NULL = 0xF0000000, + } + + [Flags] + internal enum QUIC_CONNECTION_SHUTDOWN_FLAG : uint + { + NONE = 0x0, + SILENT = 0x1 + } + + [Flags] + internal enum QUIC_STREAM_OPEN_FLAG : uint + { + NONE = 0, + UNIDIRECTIONAL = 0x1, + ZERO_RTT = 0x2, + } + + [Flags] + internal enum QUIC_STREAM_START_FLAG : uint + { + NONE = 0, + FAIL_BLOCKED = 0x1, + IMMEDIATE = 0x2, + ASYNC = 0x4, + } + + [Flags] + internal enum QUIC_STREAM_SHUTDOWN_FLAG : uint + { + NONE = 0, + GRACEFUL = 0x1, + ABORT_SEND = 0x2, + ABORT_RECV = 0x4, + ABORT = ABORT_SEND | ABORT_RECV, + IMMEDIATE = 0x8 + } + + [Flags] + internal enum QUIC_RECEIVE_FLAG : uint + { + NONE = 0, + ZERO_RTT = 0x1, + FIN = 0x02 + } + + [Flags] + internal enum QUIC_SEND_FLAG : uint + { + NONE = 0, + ALLOW_0_RTT = 0x00000001, + FIN = 0x00000002, + } + + internal enum QUIC_PARAM_LEVEL : uint + { + REGISTRATION = 0, + SESSION = 1, + LISTENER = 2, + CONNECTION = 3, + TLS = 4, + STREAM = 5, + } + + internal enum QUIC_PARAM_REGISTRATION : uint + { + RETRY_MEMORY_PERCENT = 0, + CID_PREFIX = 1 + } + + internal enum QUIC_PARAM_SESSION : uint + { + TLS_TICKET_KEY = 0, + PEER_BIDI_STREAM_COUNT = 1, + PEER_UNIDI_STREAM_COUNT = 2, + IDLE_TIMEOUT = 3, + DISCONNECT_TIMEOUT = 4, + MAX_BYTES_PER_KEY = 5 + } + + internal enum QUIC_PARAM_LISTENER : uint + { + LOCAL_ADDRESS = 0, + STATS = 1 + } + + internal enum QUIC_PARAM_CONN : uint + { + QUIC_VERSION = 0, + LOCAL_ADDRESS = 1, + REMOTE_ADDRESS = 2, + IDLE_TIMEOUT = 3, + PEER_BIDI_STREAM_COUNT = 4, + PEER_UNIDI_STREAM_COUNT = 5, + LOCAL_BIDI_STREAM_COUNT = 6, + LOCAL_UNIDI_STREAM_COUNT = 7, + CLOSE_REASON_PHRASE = 8, + STATISTICS = 9, + STATISTICS_PLAT = 10, + CERT_VALIDATION_FLAGS = 11, + KEEP_ALIVE = 12, + DISCONNECT_TIMEOUT = 13, + SEC_CONFIG = 14, + SEND_BUFFERING = 15, + SEND_PACING = 16, + SHARE_UDP_BINDING = 17, + IDEAL_PROCESSOR = 18, + MAX_STREAM_IDS = 19 + } + + internal enum QUIC_PARAM_STREAM : uint + { + ID = 0, + ZERORTT_LENGTH = 1, + IDEAL_SEND_BUFFER = 2 + } + + internal enum QUIC_LISTENER_EVENT : uint + { + NEW_CONNECTION = 0 + } + + internal enum QUIC_CONNECTION_EVENT : uint + { + CONNECTED = 0, + SHUTDOWN_INITIATED_BY_TRANSPORT = 1, + SHUTDOWN_INITIATED_BY_PEER = 2, + SHUTDOWN_COMPLETE = 3, + LOCAL_ADDRESS_CHANGED = 4, + PEER_ADDRESS_CHANGED = 5, + PEER_STREAM_STARTED = 6, + STREAMS_AVAILABLE = 7, + PEER_NEEDS_STREAMS = 8, + IDEAL_PROCESSOR_CHANGED = 9, + } + + internal enum QUIC_STREAM_EVENT : uint + { + START_COMPLETE = 0, + RECEIVE = 1, + SEND_COMPLETE = 2, + PEER_SEND_SHUTDOWN = 3, + PEER_SEND_ABORTED = 4, + PEER_RECEIVE_ABORTED = 5, + SEND_SHUTDOWN_COMPLETE = 6, + SHUTDOWN_COMPLETE = 7, + IDEAL_SEND_BUFFER_SIZE = 8, + } +} diff --git a/src/Shared/runtime/Quic/Interop/MsQuicNativeMethods.cs b/src/Shared/runtime/Quic/Interop/MsQuicNativeMethods.cs new file mode 100644 index 000000000000..aca6b41a5800 --- /dev/null +++ b/src/Shared/runtime/Quic/Interop/MsQuicNativeMethods.cs @@ -0,0 +1,488 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Net.Sockets; +using System.Runtime.InteropServices; +using System.Text; + +namespace System.Net.Quic.Implementations.MsQuic.Internal +{ + /// + /// Contains all native delegates and structs that are used with MsQuic. + /// + internal static unsafe class MsQuicNativeMethods + { + [StructLayout(LayoutKind.Sequential)] + internal struct NativeApi + { + internal uint Version; + + internal IntPtr SetContext; + internal IntPtr GetContext; + internal IntPtr SetCallbackHandler; + + internal IntPtr SetParam; + internal IntPtr GetParam; + + internal IntPtr RegistrationOpen; + internal IntPtr RegistrationClose; + + internal IntPtr SecConfigCreate; + internal IntPtr SecConfigDelete; + + internal IntPtr SessionOpen; + internal IntPtr SessionClose; + internal IntPtr SessionShutdown; + + internal IntPtr ListenerOpen; + internal IntPtr ListenerClose; + internal IntPtr ListenerStart; + internal IntPtr ListenerStop; + + internal IntPtr ConnectionOpen; + internal IntPtr ConnectionClose; + internal IntPtr ConnectionShutdown; + internal IntPtr ConnectionStart; + + internal IntPtr StreamOpen; + internal IntPtr StreamClose; + internal IntPtr StreamStart; + internal IntPtr StreamShutdown; + internal IntPtr StreamSend; + internal IntPtr StreamReceiveComplete; + internal IntPtr StreamReceiveSetEnabled; + } + + internal delegate uint SetContextDelegate( + IntPtr handle, + IntPtr context); + + internal delegate IntPtr GetContextDelegate( + IntPtr handle); + + internal delegate void SetCallbackHandlerDelegate( + IntPtr handle, + Delegate del, + IntPtr context); + + internal delegate uint SetParamDelegate( + IntPtr handle, + uint level, + uint param, + uint bufferLength, + byte* buffer); + + internal delegate uint GetParamDelegate( + IntPtr handle, + uint level, + uint param, + uint* bufferLength, + byte* buffer); + + internal delegate uint RegistrationOpenDelegate(byte[] appName, out IntPtr registrationContext); + + internal delegate void RegistrationCloseDelegate(IntPtr registrationContext); + + internal delegate void SecConfigCreateCompleteDelegate(IntPtr context, uint status, IntPtr securityConfig); + + internal delegate uint SecConfigCreateDelegate( + IntPtr registrationContext, + uint flags, + IntPtr certificate, + [MarshalAs(UnmanagedType.LPStr)]string principal, + IntPtr context, + SecConfigCreateCompleteDelegate completionHandler); + + internal delegate void SecConfigDeleteDelegate( + IntPtr securityConfig); + + internal delegate uint SessionOpenDelegate( + IntPtr registrationContext, + byte[] utf8String, + IntPtr context, + ref IntPtr session); + + internal delegate void SessionCloseDelegate( + IntPtr session); + + internal delegate void SessionShutdownDelegate( + IntPtr session, + uint flags, + ushort errorCode); + + [StructLayout(LayoutKind.Sequential)] + internal struct ListenerEvent + { + internal QUIC_LISTENER_EVENT Type; + internal ListenerEventDataUnion Data; + } + + [StructLayout(LayoutKind.Explicit)] + internal struct ListenerEventDataUnion + { + [FieldOffset(0)] + internal ListenerEventDataNewConnection NewConnection; + } + + [StructLayout(LayoutKind.Sequential)] + internal struct ListenerEventDataNewConnection + { + internal IntPtr Info; + internal IntPtr Connection; + internal IntPtr SecurityConfig; + } + + [StructLayout(LayoutKind.Sequential)] + internal struct NewConnectionInfo + { + internal uint QuicVersion; + internal IntPtr LocalAddress; + internal IntPtr RemoteAddress; + internal ushort CryptoBufferLength; + internal ushort AlpnListLength; + internal ushort ServerNameLength; + internal IntPtr CryptoBuffer; + internal IntPtr AlpnList; + internal IntPtr ServerName; + } + + internal delegate uint ListenerCallbackDelegate( + IntPtr listener, + IntPtr context, + ref ListenerEvent evt); + + internal delegate uint ListenerOpenDelegate( + IntPtr session, + ListenerCallbackDelegate handler, + IntPtr context, + out IntPtr listener); + + internal delegate uint ListenerCloseDelegate( + IntPtr listener); + + internal delegate uint ListenerStartDelegate( + IntPtr listener, + ref SOCKADDR_INET localAddress); + + internal delegate uint ListenerStopDelegate( + IntPtr listener); + + [StructLayout(LayoutKind.Sequential)] + internal struct ConnectionEventDataConnected + { + internal bool EarlyDataAccepted; + } + + [StructLayout(LayoutKind.Sequential)] + internal struct ConnectionEventDataShutdownBegin + { + internal uint Status; + } + + [StructLayout(LayoutKind.Sequential)] + internal struct ConnectionEventDataShutdownBeginPeer + { + internal long ErrorCode; + } + + [StructLayout(LayoutKind.Sequential)] + internal struct ConnectionEventDataShutdownComplete + { + internal bool TimedOut; + } + + [StructLayout(LayoutKind.Sequential)] + internal struct ConnectionEventDataLocalAddrChanged + { + internal IntPtr Address; + } + + [StructLayout(LayoutKind.Sequential)] + internal struct ConnectionEventDataPeerAddrChanged + { + internal IntPtr Address; + } + + [StructLayout(LayoutKind.Sequential)] + internal struct ConnectionEventDataNewStream + { + internal IntPtr Stream; + internal QUIC_STREAM_OPEN_FLAG Flags; + } + + [StructLayout(LayoutKind.Sequential)] + internal struct ConnectionEventDataStreamsAvailable + { + internal ushort BiDirectionalCount; + internal ushort UniDirectionalCount; + } + + [StructLayout(LayoutKind.Sequential)] + internal struct ConnectionEventDataIdealSendBuffer + { + internal ulong NumBytes; + } + + [StructLayout(LayoutKind.Explicit)] + internal struct ConnectionEventDataUnion + { + [FieldOffset(0)] + internal ConnectionEventDataConnected Connected; + + [FieldOffset(0)] + internal ConnectionEventDataShutdownBegin ShutdownBegin; + + [FieldOffset(0)] + internal ConnectionEventDataShutdownBeginPeer ShutdownBeginPeer; + + [FieldOffset(0)] + internal ConnectionEventDataShutdownComplete ShutdownComplete; + + [FieldOffset(0)] + internal ConnectionEventDataLocalAddrChanged LocalAddrChanged; + + [FieldOffset(0)] + internal ConnectionEventDataPeerAddrChanged PeerAddrChanged; + + [FieldOffset(0)] + internal ConnectionEventDataNewStream NewStream; + + [FieldOffset(0)] + internal ConnectionEventDataStreamsAvailable StreamsAvailable; + + [FieldOffset(0)] + internal ConnectionEventDataIdealSendBuffer IdealSendBuffer; + } + + [StructLayout(LayoutKind.Sequential)] + internal struct ConnectionEvent + { + internal QUIC_CONNECTION_EVENT Type; + internal ConnectionEventDataUnion Data; + + internal bool EarlyDataAccepted => Data.Connected.EarlyDataAccepted; + internal ulong NumBytes => Data.IdealSendBuffer.NumBytes; + internal uint ShutdownBeginStatus => Data.ShutdownBegin.Status; + internal long ShutdownBeginPeerStatus => Data.ShutdownBeginPeer.ErrorCode; + internal bool ShutdownTimedOut => Data.ShutdownComplete.TimedOut; + internal ushort BiDirectionalCount => Data.StreamsAvailable.BiDirectionalCount; + internal ushort UniDirectionalCount => Data.StreamsAvailable.UniDirectionalCount; + internal QUIC_STREAM_OPEN_FLAG StreamFlags => Data.NewStream.Flags; + } + + internal delegate uint ConnectionCallbackDelegate( + IntPtr connection, + IntPtr context, + ref ConnectionEvent connectionEvent); + + internal delegate uint ConnectionOpenDelegate( + IntPtr session, + ConnectionCallbackDelegate handler, + IntPtr context, + out IntPtr connection); + + internal delegate uint ConnectionCloseDelegate( + IntPtr connection); + + internal delegate uint ConnectionStartDelegate( + IntPtr connection, + ushort family, + [MarshalAs(UnmanagedType.LPStr)] + string serverName, + ushort serverPort); + + internal delegate uint ConnectionShutdownDelegate( + IntPtr connection, + uint flags, + long errorCode); + + [StructLayout(LayoutKind.Sequential)] + internal struct StreamEventDataRecv + { + internal ulong AbsoluteOffset; + internal ulong TotalBufferLength; + internal QuicBuffer* Buffers; + internal uint BufferCount; + internal uint Flags; + } + + [StructLayout(LayoutKind.Explicit)] + internal struct StreamEventDataSendComplete + { + [FieldOffset(0)] + internal byte Canceled; + [FieldOffset(1)] + internal IntPtr ClientContext; + + internal bool IsCanceled() + { + return Canceled != 0; + } + } + + [StructLayout(LayoutKind.Sequential)] + internal struct StreamEventDataPeerSendAbort + { + internal long ErrorCode; + } + + [StructLayout(LayoutKind.Sequential)] + internal struct StreamEventDataPeerRecvAbort + { + internal long ErrorCode; + } + + [StructLayout(LayoutKind.Sequential)] + internal struct StreamEventDataSendShutdownComplete + { + internal byte Graceful; + } + + [StructLayout(LayoutKind.Explicit)] + internal struct StreamEventDataUnion + { + [FieldOffset(0)] + internal StreamEventDataRecv Recv; + + [FieldOffset(0)] + internal StreamEventDataSendComplete SendComplete; + + [FieldOffset(0)] + internal StreamEventDataPeerSendAbort PeerSendAbort; + + [FieldOffset(0)] + internal StreamEventDataPeerRecvAbort PeerRecvAbort; + + [FieldOffset(0)] + internal StreamEventDataSendShutdownComplete SendShutdownComplete; + } + + [StructLayout(LayoutKind.Sequential)] + internal struct StreamEvent + { + internal QUIC_STREAM_EVENT Type; + internal StreamEventDataUnion Data; + } + + [StructLayout(LayoutKind.Sequential)] + internal struct SOCKADDR_IN + { + internal ushort sin_family; + internal ushort sin_port; + internal byte sin_addr0; + internal byte sin_addr1; + internal byte sin_addr2; + internal byte sin_addr3; + + internal byte[] Address + { + get + { + return new byte[] { sin_addr0, sin_addr1, sin_addr2, sin_addr3 }; + } + } + } + + [StructLayout(LayoutKind.Sequential)] + internal struct SOCKADDR_IN6 + { + internal ushort _family; + internal ushort _port; + internal uint _flowinfo; + internal byte _addr0; + internal byte _addr1; + internal byte _addr2; + internal byte _addr3; + internal byte _addr4; + internal byte _addr5; + internal byte _addr6; + internal byte _addr7; + internal byte _addr8; + internal byte _addr9; + internal byte _addr10; + internal byte _addr11; + internal byte _addr12; + internal byte _addr13; + internal byte _addr14; + internal byte _addr15; + internal uint _scope_id; + + internal byte[] Address + { + get + { + return new byte[] { + _addr0, _addr1, _addr2, _addr3, + _addr4, _addr5, _addr6, _addr7, + _addr8, _addr9, _addr10, _addr11, + _addr12, _addr13, _addr14, _addr15 }; + } + } + } + + [StructLayout(LayoutKind.Explicit, CharSet = CharSet.Ansi)] + internal struct SOCKADDR_INET + { + [FieldOffset(0)] + internal SOCKADDR_IN Ipv4; + [FieldOffset(0)] + internal SOCKADDR_IN6 Ipv6; + [FieldOffset(0)] + internal ushort si_family; + } + + internal delegate uint StreamCallbackDelegate( + IntPtr stream, + IntPtr context, + ref StreamEvent streamEvent); + + internal delegate uint StreamOpenDelegate( + IntPtr connection, + uint flags, + StreamCallbackDelegate handler, + IntPtr context, + out IntPtr stream); + + internal delegate uint StreamStartDelegate( + IntPtr stream, + uint flags); + + internal delegate uint StreamCloseDelegate( + IntPtr stream); + + internal delegate uint StreamShutdownDelegate( + IntPtr stream, + uint flags, + long errorCode); + + internal delegate uint StreamSendDelegate( + IntPtr stream, + QuicBuffer* buffers, + uint bufferCount, + uint flags, + IntPtr clientSendContext); + + internal delegate uint StreamReceiveCompleteDelegate( + IntPtr stream, + ulong bufferLength); + + internal delegate uint StreamReceiveSetEnabledDelegate( + IntPtr stream, + bool enabled); + + [StructLayout(LayoutKind.Sequential)] + internal unsafe struct QuicBuffer + { + internal uint Length; + internal byte* Buffer; + } + + [StructLayout(LayoutKind.Sequential)] + internal struct CertFileParams + { + internal IntPtr CertificateFilePath; + internal IntPtr PrivateKeyFilePath; + } + } +} diff --git a/src/Shared/runtime/Quic/Interop/MsQuicStatusCodes.cs b/src/Shared/runtime/Quic/Interop/MsQuicStatusCodes.cs new file mode 100644 index 000000000000..72c35687ebe5 --- /dev/null +++ b/src/Shared/runtime/Quic/Interop/MsQuicStatusCodes.cs @@ -0,0 +1,121 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Runtime.InteropServices; + +namespace System.Net.Quic.Implementations.MsQuic.Internal +{ + internal static class MsQuicStatusCodes + { + internal static readonly uint Success = RuntimeInformation.IsOSPlatform(OSPlatform.Windows) ? Windows.Success : Linux.Success; + internal static readonly uint Pending = RuntimeInformation.IsOSPlatform(OSPlatform.Windows) ? Windows.Pending : Linux.Pending; + internal static readonly uint InternalError = RuntimeInformation.IsOSPlatform(OSPlatform.Windows) ? Windows.InternalError : Linux.InternalError; + + // TODO return better error messages here. + public static string GetError(uint status) + { + return RuntimeInformation.IsOSPlatform(OSPlatform.Windows) + ? Windows.GetError(status) : Linux.GetError(status); + } + + private static class Windows + { + internal const uint Success = 0; + internal const uint Pending = 0x703E5; + internal const uint Continue = 0x704DE; + internal const uint OutOfMemory = 0x8007000E; + internal const uint InvalidParameter = 0x80070057; + internal const uint InvalidState = 0x8007139F; + internal const uint NotSupported = 0x80004002; + internal const uint NotFound = 0x80070490; + internal const uint BufferTooSmall = 0x8007007A; + internal const uint HandshakeFailure = 0x80410000; + internal const uint Aborted = 0x80004004; + internal const uint AddressInUse = 0x80072740; + internal const uint ConnectionTimeout = 0x800704CF; + internal const uint ConnectionIdle = 0x800704D4; + internal const uint InternalError = 0x80004005; + internal const uint ServerBusy = 0x800704C9; + internal const uint ProtocolError = 0x800704CD; + internal const uint HostUnreachable = 0x800704D0; + internal const uint VerNegError = 0x80410001; + + // TODO return better error messages here. + public static string GetError(uint status) + { + return status switch + { + Success => "SUCCESS", + Pending => "PENDING", + Continue => "CONTINUE", + OutOfMemory => "OUT_OF_MEMORY", + InvalidParameter => "INVALID_PARAMETER", + InvalidState => "INVALID_STATE", + NotSupported => "NOT_SUPPORTED", + NotFound => "NOT_FOUND", + BufferTooSmall => "BUFFER_TOO_SMALL", + HandshakeFailure => "HANDSHAKE_FAILURE", + Aborted => "ABORTED", + AddressInUse => "ADDRESS_IN_USE", + ConnectionTimeout => "CONNECTION_TIMEOUT", + ConnectionIdle => "CONNECTION_IDLE", + InternalError => "INTERNAL_ERROR", + ServerBusy => "SERVER_BUSY", + ProtocolError => "PROTOCOL_ERROR", + VerNegError => "VER_NEG_ERROR", + _ => status.ToString() + }; + } + } + + private static class Linux + { + internal const uint Success = 0; + internal const uint Pending = unchecked((uint)-2); + internal const uint Continue = unchecked((uint)-1); + internal const uint OutOfMemory = 12; + internal const uint InvalidParameter = 22; + internal const uint InvalidState = 200000002; + internal const uint NotSupported = 95; + internal const uint NotFound = 2; + internal const uint BufferTooSmall = 75; + internal const uint HandshakeFailure = 200000009; + internal const uint Aborted = 200000008; + internal const uint AddressInUse = 98; + internal const uint ConnectionTimeout = 110; + internal const uint ConnectionIdle = 200000011; + internal const uint InternalError = 200000012; + internal const uint ServerBusy = 200000007; + internal const uint ProtocolError = 200000013; + internal const uint VerNegError = 200000014; + + // TODO return better error messages here. + public static string GetError(uint status) + { + return status switch + { + Success => "SUCCESS", + Pending => "PENDING", + Continue => "CONTINUE", + OutOfMemory => "OUT_OF_MEMORY", + InvalidParameter => "INVALID_PARAMETER", + InvalidState => "INVALID_STATE", + NotSupported => "NOT_SUPPORTED", + NotFound => "NOT_FOUND", + BufferTooSmall => "BUFFER_TOO_SMALL", + HandshakeFailure => "HANDSHAKE_FAILURE", + Aborted => "ABORTED", + AddressInUse => "ADDRESS_IN_USE", + ConnectionTimeout => "CONNECTION_TIMEOUT", + ConnectionIdle => "CONNECTION_IDLE", + InternalError => "INTERNAL_ERROR", + ServerBusy => "SERVER_BUSY", + ProtocolError => "PROTOCOL_ERROR", + VerNegError => "VER_NEG_ERROR", + _ => status.ToString() + }; + } + } + } +} diff --git a/src/Shared/runtime/Quic/Interop/MsQuicStatusHelper.cs b/src/Shared/runtime/Quic/Interop/MsQuicStatusHelper.cs new file mode 100644 index 000000000000..f08eb861d2e6 --- /dev/null +++ b/src/Shared/runtime/Quic/Interop/MsQuicStatusHelper.cs @@ -0,0 +1,26 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Runtime.InteropServices; + +namespace System.Net.Quic.Implementations.MsQuic.Internal +{ + internal static class MsQuicStatusHelper + { + internal static bool SuccessfulStatusCode(uint status) + { + if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) + { + return status < 0x80000000; + } + + if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux)) + { + return (int)status <= 0; + } + + return false; + } + } +} diff --git a/src/Shared/runtime/Quic/NetEventSource.Quic.cs b/src/Shared/runtime/Quic/NetEventSource.Quic.cs new file mode 100644 index 000000000000..921808829dcb --- /dev/null +++ b/src/Shared/runtime/Quic/NetEventSource.Quic.cs @@ -0,0 +1,13 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Diagnostics.Tracing; + +namespace System.Net +{ + [EventSource(Name = "Microsoft-System-Net-Quic")] + internal sealed partial class NetEventSource : EventSource + { + } +} diff --git a/src/Shared/runtime/Quic/QuicClientConnectionOptions.cs b/src/Shared/runtime/Quic/QuicClientConnectionOptions.cs new file mode 100644 index 000000000000..a9a9b0ec40c4 --- /dev/null +++ b/src/Shared/runtime/Quic/QuicClientConnectionOptions.cs @@ -0,0 +1,50 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Net.Security; + +namespace System.Net.Quic +{ + /// + /// Options to provide to the when connecting to a Listener. + /// + internal class QuicClientConnectionOptions + { + /// + /// Client authentication options to use when establishing a . + /// + public SslClientAuthenticationOptions ClientAuthenticationOptions { get; set; } + + /// + /// The local endpoint that will be bound to. + /// + public IPEndPoint LocalEndPoint { get; set; } + + /// + /// The endpoint to connect to. + /// + public IPEndPoint RemoteEndPoint { get; set; } + + /// + /// Limit on the number of bidirectional streams the peer connection can create + /// on an accepted connection. + /// Default is 100. + /// + // TODO consider constraining these limits to 0 to whatever the max of the QUIC library we are using. + public long MaxBidirectionalStreams { get; set; } = 100; + + /// + /// Limit on the number of unidirectional streams the peer connection can create + /// on an accepted connection. + /// Default is 100. + /// + // TODO consider constraining these limits to 0 to whatever the max of the QUIC library we are using. + public long MaxUnidirectionalStreams { get; set; } = 100; + + /// + /// Idle timeout for connections, afterwhich the connection will be closed. + /// + public TimeSpan IdleTimeout { get; set; } = TimeSpan.FromMinutes(2); + } +} diff --git a/src/Shared/runtime/Quic/QuicConnection.cs b/src/Shared/runtime/Quic/QuicConnection.cs new file mode 100644 index 000000000000..877421a94f01 --- /dev/null +++ b/src/Shared/runtime/Quic/QuicConnection.cs @@ -0,0 +1,99 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Net.Quic.Implementations; +using System.Net.Quic.Implementations.MsQuic.Internal; +using System.Net.Security; +using System.Threading; +using System.Threading.Tasks; + +namespace System.Net.Quic +{ + internal sealed class QuicConnection : IDisposable + { + private readonly QuicConnectionProvider _provider; + + public static bool IsQuicSupported => MsQuicApi.IsQuicSupported; + + /// + /// Create an outbound QUIC connection. + /// + /// The remote endpoint to connect to. + /// TLS options + /// The local endpoint to connect from. + public QuicConnection(IPEndPoint remoteEndPoint, SslClientAuthenticationOptions sslClientAuthenticationOptions, IPEndPoint localEndPoint = null) + : this(QuicImplementationProviders.Default, remoteEndPoint, sslClientAuthenticationOptions, localEndPoint) + { + } + + // !!! TEMPORARY: Remove "implementationProvider" before shipping + public QuicConnection(QuicImplementationProvider implementationProvider, IPEndPoint remoteEndPoint, SslClientAuthenticationOptions sslClientAuthenticationOptions, IPEndPoint localEndPoint = null) + : this(implementationProvider, new QuicClientConnectionOptions() { RemoteEndPoint = remoteEndPoint, ClientAuthenticationOptions = sslClientAuthenticationOptions, LocalEndPoint = localEndPoint }) + { + } + + public QuicConnection(QuicImplementationProvider implementationProvider, QuicClientConnectionOptions options) + { + _provider = implementationProvider.CreateConnection(options); + } + + internal QuicConnection(QuicConnectionProvider provider) + { + _provider = provider; + } + + /// + /// Indicates whether the QuicConnection is connected. + /// + public bool Connected => _provider.Connected; + + public IPEndPoint LocalEndPoint => _provider.LocalEndPoint; + + public IPEndPoint RemoteEndPoint => _provider.RemoteEndPoint; + + public SslApplicationProtocol NegotiatedApplicationProtocol => _provider.NegotiatedApplicationProtocol; + + /// + /// Connect to the remote endpoint. + /// + /// + /// + public ValueTask ConnectAsync(CancellationToken cancellationToken = default) => _provider.ConnectAsync(cancellationToken); + + /// + /// Create an outbound unidirectional stream. + /// + /// + public QuicStream OpenUnidirectionalStream() => new QuicStream(_provider.OpenUnidirectionalStream()); + + /// + /// Create an outbound bidirectional stream. + /// + /// + public QuicStream OpenBidirectionalStream() => new QuicStream(_provider.OpenBidirectionalStream()); + + /// + /// Accept an incoming stream. + /// + /// + public async ValueTask AcceptStreamAsync(CancellationToken cancellationToken = default) => new QuicStream(await _provider.AcceptStreamAsync(cancellationToken).ConfigureAwait(false)); + + /// + /// Close the connection and terminate any active streams. + /// + public ValueTask CloseAsync(long errorCode, CancellationToken cancellationToken = default) => _provider.CloseAsync(errorCode, cancellationToken); + + public void Dispose() => _provider.Dispose(); + + /// + /// Gets the maximum number of bidirectional streams that can be made to the peer. + /// + public long GetRemoteAvailableUnidirectionalStreamCount() => _provider.GetRemoteAvailableUnidirectionalStreamCount(); + + /// + /// Gets the maximum number of unidirectional streams that can be made to the peer. + /// + public long GetRemoteAvailableBidirectionalStreamCount() => _provider.GetRemoteAvailableBidirectionalStreamCount(); + } +} diff --git a/src/Shared/runtime/Quic/QuicConnectionAbortedException.cs b/src/Shared/runtime/Quic/QuicConnectionAbortedException.cs new file mode 100644 index 000000000000..41f4b329983e --- /dev/null +++ b/src/Shared/runtime/Quic/QuicConnectionAbortedException.cs @@ -0,0 +1,22 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +namespace System.Net.Quic +{ + internal class QuicConnectionAbortedException : QuicException + { + internal QuicConnectionAbortedException(long errorCode) + : this(SR.Format(SR.net_quic_connectionaborted, errorCode), errorCode) + { + } + + public QuicConnectionAbortedException(string message, long errorCode) + : base (message) + { + ErrorCode = errorCode; + } + + public long ErrorCode { get; } + } +} diff --git a/src/Shared/runtime/Quic/QuicException.cs b/src/Shared/runtime/Quic/QuicException.cs new file mode 100644 index 000000000000..843c2f75924c --- /dev/null +++ b/src/Shared/runtime/Quic/QuicException.cs @@ -0,0 +1,14 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +namespace System.Net.Quic +{ + internal class QuicException : Exception + { + public QuicException(string message) + : base (message) + { + } + } +} diff --git a/src/Shared/runtime/Quic/QuicImplementationProviders.cs b/src/Shared/runtime/Quic/QuicImplementationProviders.cs new file mode 100644 index 000000000000..66a7e0d6dfb7 --- /dev/null +++ b/src/Shared/runtime/Quic/QuicImplementationProviders.cs @@ -0,0 +1,13 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +namespace System.Net.Quic +{ + internal static class QuicImplementationProviders + { + public static Implementations.QuicImplementationProvider Mock { get; } = new Implementations.Mock.MockImplementationProvider(); + public static Implementations.QuicImplementationProvider MsQuic { get; } = new Implementations.MsQuic.MsQuicImplementationProvider(); + public static Implementations.QuicImplementationProvider Default => MsQuic; + } +} diff --git a/src/Shared/runtime/Quic/QuicListener.cs b/src/Shared/runtime/Quic/QuicListener.cs new file mode 100644 index 000000000000..8fb0c1e33756 --- /dev/null +++ b/src/Shared/runtime/Quic/QuicListener.cs @@ -0,0 +1,55 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Net.Quic.Implementations; +using System.Net.Security; +using System.Threading; +using System.Threading.Tasks; + +namespace System.Net.Quic +{ + internal sealed class QuicListener : IDisposable + { + private readonly QuicListenerProvider _provider; + + /// + /// Create a QUIC listener on the specified local endpoint and start listening. + /// + /// The local endpoint to listen on. + /// TLS options for the listener. + public QuicListener(IPEndPoint listenEndPoint, SslServerAuthenticationOptions sslServerAuthenticationOptions) + : this(QuicImplementationProviders.Default, listenEndPoint, sslServerAuthenticationOptions) + { + } + + // !!! TEMPORARY: Remove "implementationProvider" before shipping + public QuicListener(QuicImplementationProvider implementationProvider, IPEndPoint listenEndPoint, SslServerAuthenticationOptions sslServerAuthenticationOptions) + : this(implementationProvider, new QuicListenerOptions() { ListenEndPoint = listenEndPoint, ServerAuthenticationOptions = sslServerAuthenticationOptions }) + { + } + + public QuicListener(QuicImplementationProvider implementationProvider, QuicListenerOptions options) + { + _provider = implementationProvider.CreateListener(options); + } + + public IPEndPoint ListenEndPoint => _provider.ListenEndPoint; + + /// + /// Accept a connection. + /// + /// + public async ValueTask AcceptConnectionAsync(CancellationToken cancellationToken = default) => + new QuicConnection(await _provider.AcceptConnectionAsync(cancellationToken).ConfigureAwait(false)); + + public void Start() => _provider.Start(); + + /// + /// Stop listening and close the listener. + /// + public void Close() => _provider.Close(); + + public void Dispose() => _provider.Dispose(); + } +} diff --git a/src/Shared/runtime/Quic/QuicListenerOptions.cs b/src/Shared/runtime/Quic/QuicListenerOptions.cs new file mode 100644 index 000000000000..f9eae30c3d5a --- /dev/null +++ b/src/Shared/runtime/Quic/QuicListenerOptions.cs @@ -0,0 +1,59 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Net.Security; + +namespace System.Net.Quic +{ + /// + /// Options to provide to the . + /// + internal class QuicListenerOptions + { + /// + /// Server Ssl options to use for ALPN, SNI, etc. + /// + public SslServerAuthenticationOptions ServerAuthenticationOptions { get; set; } + + /// + /// Optional path to certificate file to configure the security configuration. + /// + public string CertificateFilePath { get; set; } + + /// + /// Optional path to private key file to configure the security configuration. + /// + public string PrivateKeyFilePath { get; set; } + + /// + /// The endpoint to listen on. + /// + public IPEndPoint ListenEndPoint { get; set; } + + /// + /// Number of connections to be held without accepting the connection. + /// + public int ListenBacklog { get; set; } = 512; + + /// + /// Limit on the number of bidirectional streams an accepted connection can create + /// back to the client. + /// Default is 100. + /// + // TODO consider constraining these limits to 0 to whatever the max of the QUIC library we are using. + public long MaxBidirectionalStreams { get; set; } = 100; + + /// + /// Limit on the number of unidirectional streams the peer connection can create. + /// Default is 100. + /// + // TODO consider constraining these limits to 0 to whatever the max of the QUIC library we are using. + public long MaxUnidirectionalStreams { get; set; } = 100; + + /// + /// Idle timeout for connections, afterwhich the connection will be closed. + /// + public TimeSpan IdleTimeout { get; set; } = TimeSpan.FromMinutes(10); + } +} diff --git a/src/Shared/runtime/Quic/QuicOperationAbortedException.cs b/src/Shared/runtime/Quic/QuicOperationAbortedException.cs new file mode 100644 index 000000000000..25cd145ee65e --- /dev/null +++ b/src/Shared/runtime/Quic/QuicOperationAbortedException.cs @@ -0,0 +1,18 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +namespace System.Net.Quic +{ + internal class QuicOperationAbortedException : QuicException + { + internal QuicOperationAbortedException() + : base(SR.net_quic_operationaborted) + { + } + + public QuicOperationAbortedException(string message) : base(message) + { + } + } +} diff --git a/src/Shared/runtime/Quic/QuicStream.cs b/src/Shared/runtime/Quic/QuicStream.cs new file mode 100644 index 000000000000..6e52bb07530d --- /dev/null +++ b/src/Shared/runtime/Quic/QuicStream.cs @@ -0,0 +1,133 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Buffers; +using System.IO; +using System.Net.Quic.Implementations; +using System.Threading; +using System.Threading.Tasks; + +namespace System.Net.Quic +{ + internal sealed class QuicStream : Stream + { + private readonly QuicStreamProvider _provider; + + internal QuicStream(QuicStreamProvider provider) + { + _provider = provider; + } + + // + // Boilerplate implementation stuff + // + + public override bool CanSeek => false; + public override long Length => throw new NotSupportedException(); + public override long Seek(long offset, SeekOrigin origin) => throw new NotSupportedException(); + public override void SetLength(long value) => throw new NotSupportedException(); + public override long Position { get => throw new NotSupportedException(); set => throw new NotSupportedException(); } + + public override IAsyncResult BeginRead(byte[] buffer, int offset, int count, AsyncCallback callback, object state) => + TaskToApm.Begin(ReadAsync(buffer, offset, count, default), callback, state); + + public override int EndRead(IAsyncResult asyncResult) => + TaskToApm.End(asyncResult); + + public override IAsyncResult BeginWrite(byte[] buffer, int offset, int count, AsyncCallback callback, object state) => + TaskToApm.Begin(WriteAsync(buffer, offset, count, default), callback, state); + + public override void EndWrite(IAsyncResult asyncResult) => + TaskToApm.End(asyncResult); + + private static void ValidateBufferArgs(byte[] buffer, int offset, int count) + { + if (buffer == null) + { + throw new ArgumentNullException(nameof(buffer)); + } + + if ((uint)offset > buffer.Length) + { + throw new ArgumentOutOfRangeException(nameof(offset)); + } + + if ((uint)count > buffer.Length - offset) + { + throw new ArgumentOutOfRangeException(nameof(count)); + } + } + + public override int Read(byte[] buffer, int offset, int count) + { + ValidateBufferArgs(buffer, offset, count); + return Read(buffer.AsSpan(offset, count)); + } + + public override Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + ValidateBufferArgs(buffer, offset, count); + return ReadAsync(new Memory(buffer, offset, count), cancellationToken).AsTask(); + } + + public override void Write(byte[] buffer, int offset, int count) + { + ValidateBufferArgs(buffer, offset, count); + Write(buffer.AsSpan(offset, count)); + } + + public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + ValidateBufferArgs(buffer, offset, count); + return WriteAsync(new ReadOnlyMemory(buffer, offset, count), cancellationToken).AsTask(); + } + + /// + /// QUIC stream ID. + /// + public long StreamId => _provider.StreamId; + + public override bool CanRead => _provider.CanRead; + + public override int Read(Span buffer) => _provider.Read(buffer); + + public override ValueTask ReadAsync(Memory buffer, CancellationToken cancellationToken = default) => _provider.ReadAsync(buffer, cancellationToken); + + public override bool CanWrite => _provider.CanWrite; + + public override void Write(ReadOnlySpan buffer) => _provider.Write(buffer); + + public override ValueTask WriteAsync(ReadOnlyMemory buffer, CancellationToken cancellationToken = default) => _provider.WriteAsync(buffer, cancellationToken); + + public override void Flush() => _provider.Flush(); + + public override Task FlushAsync(CancellationToken cancellationToken) => _provider.FlushAsync(cancellationToken); + + public void AbortRead(long errorCode) => _provider.AbortRead(errorCode); + + public void AbortWrite(long errorCode) => _provider.AbortWrite(errorCode); + + public ValueTask WriteAsync(ReadOnlyMemory buffer, bool endStream, CancellationToken cancellationToken = default) => _provider.WriteAsync(buffer, endStream, cancellationToken); + + public ValueTask WriteAsync(ReadOnlySequence buffers, CancellationToken cancellationToken = default) => _provider.WriteAsync(buffers, cancellationToken); + + public ValueTask WriteAsync(ReadOnlySequence buffers, bool endStream, CancellationToken cancellationToken = default) => _provider.WriteAsync(buffers, endStream, cancellationToken); + + public ValueTask WriteAsync(ReadOnlyMemory> buffers, CancellationToken cancellationToken = default) => _provider.WriteAsync(buffers, cancellationToken); + + public ValueTask WriteAsync(ReadOnlyMemory> buffers, bool endStream, CancellationToken cancellationToken = default) => _provider.WriteAsync(buffers, endStream, cancellationToken); + + public ValueTask ShutdownWriteCompleted(CancellationToken cancellationToken = default) => _provider.ShutdownWriteCompleted(cancellationToken); + + public void Shutdown() => _provider.Shutdown(); + + protected override void Dispose(bool disposing) + { + if (disposing) + { + _provider.Dispose(); + } + } + } +} diff --git a/src/Shared/runtime/Quic/QuicStreamAbortedException.cs b/src/Shared/runtime/Quic/QuicStreamAbortedException.cs new file mode 100644 index 000000000000..6e25335f9992 --- /dev/null +++ b/src/Shared/runtime/Quic/QuicStreamAbortedException.cs @@ -0,0 +1,22 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +namespace System.Net.Quic +{ + internal class QuicStreamAbortedException : QuicException + { + internal QuicStreamAbortedException(long errorCode) + : this(SR.Format(SR.net_quic_streamaborted, errorCode), errorCode) + { + } + + public QuicStreamAbortedException(string message, long errorCode) + : base(message) + { + ErrorCode = errorCode; + } + + public long ErrorCode { get; } + } +} diff --git a/src/Shared/runtime/SR.Quic.cs b/src/Shared/runtime/SR.Quic.cs new file mode 100644 index 000000000000..6a756f9ca6f9 --- /dev/null +++ b/src/Shared/runtime/SR.Quic.cs @@ -0,0 +1,20 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +namespace System.Net.Quic +{ + internal static partial class SR + { + // The resource generator used in AspNetCore does not create this method. This file fills in that functional gap + // so we don't have to modify the shared source. + internal static string Format(string resourceFormat, params object[] args) + { + if (args != null) + { + return string.Format(resourceFormat, args); + } + + return resourceFormat; + } + } +} diff --git a/src/Shared/runtime/SR.resx b/src/Shared/runtime/SR.resx index bcd721a2c5f8..c6c8b9bd5b25 100644 --- a/src/Shared/runtime/SR.resx +++ b/src/Shared/runtime/SR.resx @@ -156,4 +156,16 @@ Request headers must contain only ASCII characters. + + Connection aborted by peer ({0}). + + + QUIC is not supported on this platform. See http://aka.ms/dotnetquic + + + Operation aborted. + + + Stream aborted by peer ({0}). + \ No newline at end of file