diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIProxy.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIProxy.cs index 339e6be972..3df369a2f6 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIProxy.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIProxy.cs @@ -230,7 +230,7 @@ private static byte[][] GetSqlServerSPNs(DataSource dataSource, string serverSPN } else if (!string.IsNullOrWhiteSpace(dataSource.InstanceName)) { - postfix = dataSource.InstanceName; + postfix = dataSource._connectionProtocol == DataSource.Protocol.TCP ? dataSource.ResolvedPort.ToString() : dataSource.InstanceName; } SqlClientEventSource.Log.TryTraceEvent("SNIProxy.GetSqlServerSPN | Info | ServerName {0}, InstanceName {1}, Port {2}, postfix {3}", dataSource?.ServerName, dataSource?.InstanceName, dataSource?.Port, postfix); @@ -317,7 +317,7 @@ private static SNITCPHandle CreateTcpHandle( { try { - port = isAdminConnection ? + details.ResolvedPort = port = isAdminConnection ? SSRP.GetDacPortByInstanceName(hostName, details.InstanceName, timeout, parallel, ipPreference) : SSRP.GetPortByInstanceName(hostName, details.InstanceName, timeout, parallel, ipPreference); } @@ -436,6 +436,11 @@ internal enum Protocol { TCP, NP, None, Admin }; /// internal int Port { get; private set; } = -1; + /// + /// The port resolved by SSRP when InstanceName is specified + /// + internal int ResolvedPort { get; set; } = -1; + /// /// Provides the inferred Instance Name from Server Data Source /// diff --git a/src/Microsoft.Data.SqlClient/tests/ManualTests/DataCommon/DataTestUtility.cs b/src/Microsoft.Data.SqlClient/tests/ManualTests/DataCommon/DataTestUtility.cs index 07f866f98b..e167a264a2 100644 --- a/src/Microsoft.Data.SqlClient/tests/ManualTests/DataCommon/DataTestUtility.cs +++ b/src/Microsoft.Data.SqlClient/tests/ManualTests/DataCommon/DataTestUtility.cs @@ -1073,6 +1073,21 @@ public static string GetMachineFQDN(string hostname) return fqdn.ToString(); } + public static bool IsNotLocalhost() + { + // get the tcp connection string + SqlConnectionStringBuilder builder = new(DataTestUtility.TCPConnectionString); + + string hostname = ""; + + // parse the datasource + ParseDataSource(builder.DataSource, out hostname, out _, out _); + + // hostname must not be localhost, ., 127.0.0.1 nor ::1 + return !(new string[] { "localhost", ".", "127.0.0.1", "::1" }).Contains(hostname.ToLowerInvariant()); + + } + private static bool RunningAsUWPApp() { if (!RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) diff --git a/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/InstanceNameTest/InstanceNameTest.cs b/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/InstanceNameTest/InstanceNameTest.cs index 311ce5fce5..087b44d964 100644 --- a/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/InstanceNameTest/InstanceNameTest.cs +++ b/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/InstanceNameTest/InstanceNameTest.cs @@ -5,6 +5,7 @@ using System; using System.Net; using System.Net.Sockets; +using System.Reflection; using System.Text; using System.Threading.Tasks; using Xunit; @@ -83,6 +84,138 @@ public static void ConnectManagedWithInstanceNameTest(bool useMultiSubnetFailove } } + // Note: This Unit test was tested in a domain-joined VM connecting to a remote + // SQL Server using Kerberos in the same domain. + [ActiveIssue("27824")] // When specifying instance name and port number, this method call always returns false + [ConditionalFact(nameof(IsKerberos))] + public static void PortNumberInSPNTest() + { + string connStr = DataTestUtility.TCPConnectionString; + // If config.json.SupportsIntegratedSecurity = true, replace all keys defined below with Integrated Security=true + if (DataTestUtility.IsIntegratedSecuritySetup()) + { + string[] removeKeys = { "Authentication", "User ID", "Password", "UID", "PWD", "Trusted_Connection" }; + connStr = DataTestUtility.RemoveKeysInConnStr(DataTestUtility.TCPConnectionString, removeKeys) + $"Integrated Security=true"; + } + + SqlConnectionStringBuilder builder = new(connStr); + + Assert.True(DataTestUtility.ParseDataSource(builder.DataSource, out string hostname, out _, out string instanceName), "Data source to be parsed must contain a host name and instance name"); + + bool condition = IsBrowserAlive(hostname) && IsValidInstance(hostname, instanceName); + Assert.True(condition, "Browser service is not running or instance name is invalid"); + + if (condition) + { + using SqlConnection connection = new(builder.ConnectionString); + connection.Open(); + using SqlCommand command = new("SELECT auth_scheme, local_tcp_port from sys.dm_exec_connections where session_id = @@spid", connection); + using SqlDataReader reader = command.ExecuteReader(); + Assert.True(reader.Read(), "Expected to receive one row data"); + Assert.Equal("KERBEROS", reader.GetString(0)); + int localTcpPort = reader.GetInt32(1); + + int spnPort = -1; + string spnInfo = GetSPNInfo(builder.DataSource, out spnPort); + + // sample output to validate = MSSQLSvc/machine.domain.tld:spnPort" + Assert.Contains($"MSSQLSvc/{hostname}", spnInfo); + // the local_tcp_port should be the same as the inferred SPN port from instance name + Assert.Equal(localTcpPort, spnPort); + } + } + + private static string GetSPNInfo(string datasource, out int out_port) + { + Assembly sqlConnectionAssembly = Assembly.GetAssembly(typeof(SqlConnection)); + + // Get all required types using reflection + Type sniProxyType = sqlConnectionAssembly.GetType("Microsoft.Data.SqlClient.SNI.SNIProxy"); + Type ssrpType = sqlConnectionAssembly.GetType("Microsoft.Data.SqlClient.SNI.SSRP"); + Type dataSourceType = sqlConnectionAssembly.GetType("Microsoft.Data.SqlClient.SNI.DataSource"); + Type timeoutTimerType = sqlConnectionAssembly.GetType("Microsoft.Data.ProviderBase.TimeoutTimer"); + + // Used in Datasource constructor param type array + Type[] dataSourceConstructorTypesArray = new Type[] { typeof(string) }; + + // Used in GetSqlServerSPNs function param types array + Type[] getSqlServerSPNsTypesArray = new Type[] { dataSourceType, typeof(string) }; + + // GetPortByInstanceName parameters array + Type[] getPortByInstanceNameTypesArray = new Type[] { typeof(string), typeof(string), timeoutTimerType, typeof(bool), typeof(Microsoft.Data.SqlClient.SqlConnectionIPAddressPreference) }; + + // TimeoutTimer.StartSecondsTimeout params + Type[] startSecondsTimeoutTypesArray = new Type[] { typeof(int) }; + + // Get all types constructors + ConstructorInfo sniProxyCtor = sniProxyType.GetConstructor(BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic, null, CallingConventions.Any, Type.EmptyTypes, null); + ConstructorInfo SSRPCtor = ssrpType.GetConstructor(BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic, null, CallingConventions.Any, Type.EmptyTypes, null); + ConstructorInfo dataSourceCtor = dataSourceType.GetConstructor(BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic, null, CallingConventions.Any, dataSourceConstructorTypesArray, null); + ConstructorInfo timeoutTimerCtor = timeoutTimerType.GetConstructor(BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic, null, CallingConventions.Any, Type.EmptyTypes, null); + + // Instantiate SNIProxy + object sniProxy = sniProxyCtor.Invoke(new object[] { }); + + // Instantiate datasource + object dataSourceObj = dataSourceCtor.Invoke(new object[] { datasource }); + + // Instantiate SSRP + object ssrp = SSRPCtor.Invoke(new object[] { }); + + // Instantiate TimeoutTimer + object timeoutTimer = timeoutTimerCtor.Invoke(new object[] { }); + + // Get TimeoutTimer.StartSecondsTimeout Method + MethodInfo startSecondsTimeout = timeoutTimer.GetType().GetMethod("StartSecondsTimeout", BindingFlags.Static | BindingFlags.Public | BindingFlags.NonPublic, null, CallingConventions.Any, startSecondsTimeoutTypesArray, null); + // Create a timeoutTimer that expires in 30 seconds + timeoutTimer = startSecondsTimeout.Invoke(dataSourceObj, new object[] { 30 }); + + // Parse the datasource to separate the server name and instance name + MethodInfo ParseServerName = dataSourceObj.GetType().GetMethod("ParseServerName", BindingFlags.Static | BindingFlags.Public | BindingFlags.NonPublic, null, CallingConventions.Any, dataSourceConstructorTypesArray, null); + object dataSrcInfo = ParseServerName.Invoke(dataSourceObj, new object[] { datasource }); + + // Get the GetPortByInstanceName method of SSRP + MethodInfo getPortByInstanceName = ssrp.GetType().GetMethod("GetPortByInstanceName", BindingFlags.Static | BindingFlags.Public | BindingFlags.NonPublic, null, CallingConventions.Any, getPortByInstanceNameTypesArray, null); + + // Get the server name + PropertyInfo serverInfo = dataSrcInfo.GetType().GetProperty("ServerName", BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic); + string serverName = serverInfo.GetValue(dataSrcInfo, null).ToString(); + + // Get the instance name + PropertyInfo instanceNameInfo = dataSrcInfo.GetType().GetProperty("InstanceName", BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic); + string instanceName = instanceNameInfo.GetValue(dataSrcInfo, null).ToString(); + + // Get the port number using the GetPortByInstanceName method of SSRP + object port = getPortByInstanceName.Invoke(ssrp, parameters: new object[] { serverName, instanceName, timeoutTimer, false, 0 }); + + // Set the resolved port property of datasource + PropertyInfo resolvedPortInfo = dataSrcInfo.GetType().GetProperty("ResolvedPort", BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic); + resolvedPortInfo.SetValue(dataSrcInfo, (int)port, null); + + // Prepare the GetSqlServerSPNs method + string serverSPN = ""; + MethodInfo getSqlServerSPNs = sniProxy.GetType().GetMethod("GetSqlServerSPNs", BindingFlags.Static | BindingFlags.Public | BindingFlags.NonPublic, null, CallingConventions.Any, getSqlServerSPNsTypesArray, null); + + // Finally call GetSqlServerSPNs + byte[][] result = (byte[][])getSqlServerSPNs.Invoke(sniProxy, new object[] { dataSrcInfo, serverSPN }); + + // Example result: MSSQLSvc/machine.domain.tld:port" + string spnInfo = Encoding.Unicode.GetString(result[0]); + + out_port = (int)port; + + return spnInfo; + } + + private static bool IsKerberos() + { + return (DataTestUtility.AreConnStringsSetup() + && DataTestUtility.IsNotLocalhost() + && DataTestUtility.IsKerberosTest + && DataTestUtility.IsNotAzureServer() + && DataTestUtility.IsNotAzureSynapse()); + } + private static bool IsBrowserAlive(string browserHostname) { const byte ClntUcastEx = 0x03;