diff --git a/src/libraries/Common/src/Interop/Unix/System.Native/Interop.Select.cs b/src/libraries/Common/src/Interop/Unix/System.Native/Interop.Select.cs new file mode 100644 index 00000000000000..245e41d81138b5 --- /dev/null +++ b/src/libraries/Common/src/Interop/Unix/System.Native/Interop.Select.cs @@ -0,0 +1,14 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Runtime.InteropServices; + +internal static partial class Interop +{ + internal static partial class Sys + { + [LibraryImport(Libraries.SystemNative, EntryPoint = "SystemNative_Select")] + internal static unsafe partial Error Select(Span readFDs, int readFDsLength, Span writeFDs, int writeFDsLength, Span checkError, int checkErrorLength, int timeout, int maxFd, out int triggered); + } +} diff --git a/src/libraries/System.Net.Sockets/src/System.Net.Sockets.csproj b/src/libraries/System.Net.Sockets/src/System.Net.Sockets.csproj index 5198e809061635..e7a4cfbcce3c9f 100644 --- a/src/libraries/System.Net.Sockets/src/System.Net.Sockets.csproj +++ b/src/libraries/System.Net.Sockets/src/System.Net.Sockets.csproj @@ -267,6 +267,8 @@ Link="Common\Interop\Unix\System.Native\Interop.ReceiveMessage.cs" /> + 60 e.g. close to 64 we have in some other places + Span readFDs = checkRead?.Count > MaxStackAllocCount ? new int[checkRead.Count] : stackalloc int[checkRead?.Count ?? 0]; + Span writeFDs = checkWrite?.Count > MaxStackAllocCount ? new int[checkWrite.Count] : stackalloc int[checkWrite?.Count ?? 0]; + Span errorFDs = checkError?.Count > MaxStackAllocCount ? new int[checkError.Count] : stackalloc int[checkError?.Count ?? 0]; + + int refsAdded = 0; + int maxFd = 0; + try + { + AddDesriptors(readFDs, checkRead, ref refsAdded, ref maxFd); + AddDesriptors(writeFDs, checkWrite, ref refsAdded, ref maxFd); + AddDesriptors(errorFDs, checkError, ref refsAdded, ref maxFd); + + int triggered = 0; + Interop.Error err = Interop.Sys.Select(readFDs, readFDs.Length, writeFDs, writeFDs.Length, errorFDs, errorFDs.Length, microseconds, maxFd, out triggered); + if (err != Interop.Error.SUCCESS) + { + return GetSocketErrorForErrorCode(err); + } + + Socket.SocketListDangerousReleaseRefs(checkRead, ref refsAdded); + Socket.SocketListDangerousReleaseRefs(checkWrite, ref refsAdded); + Socket.SocketListDangerousReleaseRefs(checkError, ref refsAdded); + + if (triggered == 0) + { + checkRead?.Clear(); + checkWrite?.Clear(); + checkError?.Clear(); + } + else + { + FilterSelectList(checkRead, readFDs); + FilterSelectList(checkWrite, writeFDs); + FilterSelectList(checkError, errorFDs); + } + } + finally + { + // This order matches with the AddToPollArray calls + // to release only the handles that were ref'd. + Socket.SocketListDangerousReleaseRefs(checkRead, ref refsAdded); + Socket.SocketListDangerousReleaseRefs(checkWrite, ref refsAdded); + Socket.SocketListDangerousReleaseRefs(checkError, ref refsAdded); + Debug.Assert(refsAdded == 0); + } + + return (SocketError)0; + } + + private static void AddDesriptors(Span buffer, IList? socketList, ref int refsAdded, ref int maxFd) + { + if (socketList == null || socketList.Count == 0 ) + { + return; + } + + Debug.Assert(buffer.Length == socketList.Count); + for (int i = 0; i < socketList.Count; i++) + { + Socket? socket = socketList[i] as Socket; + if (socket == null) + { + throw new ArgumentException(SR.Format(SR.net_sockets_select, socket?.GetType().FullName ?? "null", typeof(Socket).FullName), nameof(socketList)); + } + + if (socket.Handle > maxFd) + { + maxFd = (int)socket.Handle; + } + + bool success = false; + socket.InternalSafeHandle.DangerousAddRef(ref success); + buffer[i] = (int)socket.InternalSafeHandle.DangerousGetHandle(); + + refsAdded++; + } + } + + private static void FilterSelectList(IList? socketList, Span results) + { + if (socketList == null) + return; + + // This loop can be O(n^2) in the unexpected and worst case. Some more thoughts are written in FilterPollList that does exactly same operation. + + for (int i = socketList.Count - 1; i >= 0; --i) + { + if (results[i] == 0) + { + socketList.RemoveAt(i); + } + } + } + private static unsafe SocketError SelectViaPoll( IList? checkRead, int checkReadInitialCount, IList? checkWrite, int checkWriteInitialCount, diff --git a/src/libraries/System.Net.Sockets/tests/FunctionalTests/SelectTest.cs b/src/libraries/System.Net.Sockets/tests/FunctionalTests/SelectTest.cs index 0b49749e09aeec..9c97069e3110fc 100644 --- a/src/libraries/System.Net.Sockets/tests/FunctionalTests/SelectTest.cs +++ b/src/libraries/System.Net.Sockets/tests/FunctionalTests/SelectTest.cs @@ -5,7 +5,7 @@ using System.Linq; using System.Threading; using System.Threading.Tasks; - +using Microsoft.DotNet.XUnitExtensions; using Xunit; using Xunit.Abstractions; @@ -21,7 +21,7 @@ public SelectTest(ITestOutputHelper output) } private const int SmallTimeoutMicroseconds = 10 * 1000; - private const int FailTimeoutMicroseconds = 30 * 1000 * 1000; + internal const int FailTimeoutMicroseconds = 30 * 1000 * 1000; [SkipOnPlatform(TestPlatforms.OSX, "typical OSX install has very low max open file descriptors value")] [Theory] @@ -78,6 +78,82 @@ public void Select_ReadWrite_AllReady(int reads, int writes) } } + [Theory] + [InlineData(true)] + [InlineData(false)] + public void Select_ReadError_Success(bool dispose) + { + using Socket listener = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Unspecified); + using Socket sender = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Unspecified); + + listener.Bind(new IPEndPoint(IPAddress.Loopback, 0)); + listener.Listen(1); + sender.Connect(listener.LocalEndPoint); + using Socket receiver = listener.Accept(); + + if (dispose) + { + sender.Dispose(); + } + else + { + sender.Send(new byte[] { 1 }); + } + + var readList = new List { receiver }; + var errorList = new List { receiver }; + Socket.Select(readList, null, errorList, -1); + if (dispose) + { + Assert.True(readList.Count == 1 || errorList.Count == 1); + } + else + { + Assert.Equal(1, readList.Count); + Assert.Equal(0, errorList.Count); + } + } + + [Fact] + public void Select_WriteError_Success() + { + using Socket listener = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Unspecified); + using Socket sender = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Unspecified); + + listener.Bind(new IPEndPoint(IPAddress.Loopback, 0)); + listener.Listen(1); + sender.Connect(listener.LocalEndPoint); + using Socket receiver = listener.Accept(); + + var writeList = new List { receiver }; + var errorList = new List { receiver }; + Socket.Select(null, writeList, errorList, -1); + Assert.Equal(1, writeList.Count); + Assert.Equal(0, errorList.Count); + } + + [Fact] + public void Select_ReadWriteError_Success() + { + using Socket listener = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Unspecified); + using Socket sender = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Unspecified); + + listener.Bind(new IPEndPoint(IPAddress.Loopback, 0)); + listener.Listen(1); + sender.Connect(listener.LocalEndPoint); + using Socket receiver = listener.Accept(); + + sender.Send(new byte[] { 1 }); + receiver.Poll(FailTimeoutMicroseconds, SelectMode.SelectRead); + var readList = new List { receiver }; + var writeList = new List { receiver }; + var errorList = new List { receiver }; + Socket.Select(readList, writeList, errorList, -1); + Assert.Equal(1, readList.Count); + Assert.Equal(1, writeList.Count); + Assert.Equal(0, errorList.Count); + } + [Theory] [InlineData(2, 0)] [InlineData(2, 1)] @@ -109,7 +185,6 @@ public void Select_SocketAlreadyClosed_AllSocketsClosableAfterException(int sock } } - [SkipOnPlatform(TestPlatforms.OSX, "typical OSX install has very low max open file descriptors value")] [Fact] [ActiveIssue("https://github.com/dotnet/runtime/issues/51392", TestPlatforms.iOS | TestPlatforms.tvOS | TestPlatforms.MacCatalyst)] public void Select_ReadError_NoneReady_ManySockets() @@ -245,7 +320,7 @@ public void Poll_ReadReady_LongTimeouts(int microsecondsTimeout) } } - private static KeyValuePair CreateConnectedSockets() + internal static KeyValuePair CreateConnectedSockets() { using (Socket listener = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp)) { @@ -342,5 +417,29 @@ private static void DoAccept(Socket listenSocket, int connectionsToAccept) } } } + + [ConditionalFact] + public void Select_LargeNumber_Succcess() + { + const int MaxSockets = 1025; + KeyValuePair[] socketPairs; + try + { + // we try to shoot for more socket than FD_SETSIZE (that is typically 1024) + socketPairs = Enumerable.Range(0, MaxSockets).Select(_ => SelectTest.CreateConnectedSockets()).ToArray(); + } + catch + { + throw new SkipTestException("Unable to open large count number of socket"); + } + + var readList = new List(socketPairs.Select(p => p.Key).ToArray()); + + // Try to write and read on last sockets + (Socket reader, Socket writer) = socketPairs[MaxSockets - 1]; + writer.Send(new byte[1]); + Socket.Select(readList, null, null, SelectTest.FailTimeoutMicroseconds); + Assert.Equal(1, readList.Count); + } } } diff --git a/src/libraries/System.Net.Sockets/tests/FunctionalTests/SocketOptionNameTest.cs b/src/libraries/System.Net.Sockets/tests/FunctionalTests/SocketOptionNameTest.cs index a62ccc405b42c0..26752dd31f5dcb 100644 --- a/src/libraries/System.Net.Sockets/tests/FunctionalTests/SocketOptionNameTest.cs +++ b/src/libraries/System.Net.Sockets/tests/FunctionalTests/SocketOptionNameTest.cs @@ -239,11 +239,8 @@ public void FailedConnect_GetSocketOption_SocketOptionNameError(bool simpleGet) Assert.ThrowsAny(() => client.Connect(server.LocalEndPoint)); } - // Verify via Select that there's an error - const int FailedTimeout = 10 * 1000 * 1000; // 10 seconds - var errorList = new List { client }; - Socket.Select(null, null, errorList, FailedTimeout); - Assert.Equal(1, errorList.Count); + // Verify via Poll that there's an error + Assert.True(client.Poll(10_000_000, SelectMode.SelectError)); // Get the last error and validate it's what's expected int errorCode; diff --git a/src/native/libs/System.Native/entrypoints.c b/src/native/libs/System.Native/entrypoints.c index 51c761109159b5..03dd2f65083e01 100644 --- a/src/native/libs/System.Native/entrypoints.c +++ b/src/native/libs/System.Native/entrypoints.c @@ -282,6 +282,7 @@ static const Entry s_sysNative[] = DllImportEntry(SystemNative_GetGroupName) DllImportEntry(SystemNative_GetUInt64OSThreadId) DllImportEntry(SystemNative_TryGetUInt32OSThreadId) + DllImportEntry(SystemNative_Select) }; EXTERN_C const void* SystemResolveDllImport(const char* name); diff --git a/src/native/libs/System.Native/pal_networking.c b/src/native/libs/System.Native/pal_networking.c index 2916c33206cb08..dc727fe5465045 100644 --- a/src/native/libs/System.Native/pal_networking.c +++ b/src/native/libs/System.Native/pal_networking.c @@ -1,13 +1,16 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +#if defined(__APPLE__) && __APPLE__ +#define _DARWIN_UNLIMITED_SELECT 1 +#endif + #include "pal_config.h" #include "pal_networking.h" #include "pal_safecrt.h" #include "pal_utilities.h" #include #include - #include #include #include @@ -21,6 +24,7 @@ #include #elif HAVE_SYS_POLL_H #include +#include #endif #if HAVE_SYS_PROCINFO_H #include @@ -31,7 +35,6 @@ #include #include #include -#include #include #include #include @@ -2762,6 +2765,97 @@ int32_t SystemNative_GetBytesAvailable(intptr_t socket, int32_t* available) return Error_SUCCESS; } +int32_t SystemNative_Select(int* readFds, int readFdsCount, int* writeFds, int writeFdsCount, int* errorFds, int errorFdsCount, int32_t microseconds, int maxFd, int* triggered) +{ +#ifdef _DARWIN_UNLIMITED_SELECT + fd_set readSet; + fd_set writeSet; + fd_set errorSet; + + fd_set* readSetPtr; + fd_set* writeSetPtr; + fd_set* errorSetPtr; + + if (maxFd < FD_SETSIZE) + { + FD_ZERO(&readSet); + FD_ZERO(&writeSet); + FD_ZERO(&errorSet); + readSetPtr = readFdsCount == 0 ? NULL : &readSet; + writeSetPtr= writeFdsCount == 0 ? NULL : &writeSet; + errorSetPtr = errorFdsCount == 0 ? NULL : &errorSet; + } + else + { + readSetPtr = readFdsCount == 0 ? NULL : calloc( __DARWIN_howmany(maxFd, __DARWIN_NFDBITS), sizeof(int32_t)); + writeSetPtr = writeFdsCount == 0 ? NULL : calloc( __DARWIN_howmany(maxFd, __DARWIN_NFDBITS), sizeof(int32_t)); + errorSetPtr = errorFdsCount == 0 ? NULL : calloc( __DARWIN_howmany(maxFd, __DARWIN_NFDBITS), sizeof(int32_t)); + } + + + struct timeval timeout; + timeout.tv_sec = microseconds / 1000000; + timeout.tv_usec = microseconds % 1000000; + + int fd; + for (int i = 0 ; i < readFdsCount; i++) + { + fd = *(readFds + i); + __DARWIN_FD_SET(fd, readSetPtr); + } + for (int i = 0 ; i < writeFdsCount; i++) + { + fd = *(writeFds + i); + __DARWIN_FD_SET(fd, writeSetPtr); + } + for (int i = 0 ; i < errorFdsCount; i++) + { + fd = *(errorFds + i); + __DARWIN_FD_SET(fd, errorSetPtr); + } + + *triggered = select(maxFd + 1, readSetPtr, writeSetPtr, errorSetPtr, microseconds < 0 ? NULL : &timeout); + + if (*triggered < 0) + { + return SystemNative_ConvertErrorPlatformToPal(errno); + } + + for (int i = 0 ; i < readFdsCount; i++) + { + fd = *(readFds + i); + *(readFds + i) = __DARWIN_FD_ISSET(fd, readSetPtr); + } + for (int i = 0 ; i < writeFdsCount; i++) + { + fd = *(writeFds + i); + *(writeFds + i) = __DARWIN_FD_ISSET(fd, writeSetPtr); + } + for (int i = 0 ; i < errorFdsCount; i++) + { + fd = *(errorFds + i); + *(errorFds + i) = __DARWIN_FD_ISSET(fd, errorSetPtr); + } + + if (maxFd >= FD_SETSIZE) + { + free(readSetPtr); + free(writeSetPtr); + free(errorSetPtr); + } + + return Error_SUCCESS; +#else + // avoid unused parameters warnings + (void*)readFds; + (void*)writeFds; + (void*)errorFds; + (void*)triggered; + readFdsCount + writeFdsCount + errorFdsCount + microseconds + maxFd; + return SystemNative_ConvertErrorPlatformToPal(ENOTSUP); +#endif +} + #if HAVE_EPOLL static const size_t SocketEventBufferElementSize = sizeof(struct epoll_event) > sizeof(SocketEvent) ? sizeof(struct epoll_event) : sizeof(SocketEvent); diff --git a/src/native/libs/System.Native/pal_networking.h b/src/native/libs/System.Native/pal_networking.h index f10f490c5db915..a0904f295267dc 100644 --- a/src/native/libs/System.Native/pal_networking.h +++ b/src/native/libs/System.Native/pal_networking.h @@ -427,3 +427,5 @@ PALEXPORT int32_t SystemNative_SendFile(intptr_t out_fd, intptr_t in_fd, int64_t PALEXPORT int32_t SystemNative_Disconnect(intptr_t socket); PALEXPORT uint32_t SystemNative_InterfaceNameToIndex(char* interfaceName); + +PALEXPORT int32_t SystemNative_Select(int* readFds, int readFdsCount, int* writeFds, int writeFdsCount, int* errorFds, int errorFdsCount, int32_t microseconds, int32_t maxFd, int* triggered);