diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIProxy.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIProxy.cs
index 339e6be972..3df369a2f6 100644
--- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIProxy.cs
+++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIProxy.cs
@@ -230,7 +230,7 @@ private static byte[][] GetSqlServerSPNs(DataSource dataSource, string serverSPN
}
else if (!string.IsNullOrWhiteSpace(dataSource.InstanceName))
{
- postfix = dataSource.InstanceName;
+ postfix = dataSource._connectionProtocol == DataSource.Protocol.TCP ? dataSource.ResolvedPort.ToString() : dataSource.InstanceName;
}
SqlClientEventSource.Log.TryTraceEvent("SNIProxy.GetSqlServerSPN | Info | ServerName {0}, InstanceName {1}, Port {2}, postfix {3}", dataSource?.ServerName, dataSource?.InstanceName, dataSource?.Port, postfix);
@@ -317,7 +317,7 @@ private static SNITCPHandle CreateTcpHandle(
{
try
{
- port = isAdminConnection ?
+ details.ResolvedPort = port = isAdminConnection ?
SSRP.GetDacPortByInstanceName(hostName, details.InstanceName, timeout, parallel, ipPreference) :
SSRP.GetPortByInstanceName(hostName, details.InstanceName, timeout, parallel, ipPreference);
}
@@ -436,6 +436,11 @@ internal enum Protocol { TCP, NP, None, Admin };
///
internal int Port { get; private set; } = -1;
+ ///
+ /// The port resolved by SSRP when InstanceName is specified
+ ///
+ internal int ResolvedPort { get; set; } = -1;
+
///
/// Provides the inferred Instance Name from Server Data Source
///
diff --git a/src/Microsoft.Data.SqlClient/tests/ManualTests/DataCommon/DataTestUtility.cs b/src/Microsoft.Data.SqlClient/tests/ManualTests/DataCommon/DataTestUtility.cs
index 07f866f98b..e167a264a2 100644
--- a/src/Microsoft.Data.SqlClient/tests/ManualTests/DataCommon/DataTestUtility.cs
+++ b/src/Microsoft.Data.SqlClient/tests/ManualTests/DataCommon/DataTestUtility.cs
@@ -1073,6 +1073,21 @@ public static string GetMachineFQDN(string hostname)
return fqdn.ToString();
}
+ public static bool IsNotLocalhost()
+ {
+ // get the tcp connection string
+ SqlConnectionStringBuilder builder = new(DataTestUtility.TCPConnectionString);
+
+ string hostname = "";
+
+ // parse the datasource
+ ParseDataSource(builder.DataSource, out hostname, out _, out _);
+
+ // hostname must not be localhost, ., 127.0.0.1 nor ::1
+ return !(new string[] { "localhost", ".", "127.0.0.1", "::1" }).Contains(hostname.ToLowerInvariant());
+
+ }
+
private static bool RunningAsUWPApp()
{
if (!RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
diff --git a/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/InstanceNameTest/InstanceNameTest.cs b/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/InstanceNameTest/InstanceNameTest.cs
index 311ce5fce5..087b44d964 100644
--- a/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/InstanceNameTest/InstanceNameTest.cs
+++ b/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/InstanceNameTest/InstanceNameTest.cs
@@ -5,6 +5,7 @@
using System;
using System.Net;
using System.Net.Sockets;
+using System.Reflection;
using System.Text;
using System.Threading.Tasks;
using Xunit;
@@ -83,6 +84,138 @@ public static void ConnectManagedWithInstanceNameTest(bool useMultiSubnetFailove
}
}
+ // Note: This Unit test was tested in a domain-joined VM connecting to a remote
+ // SQL Server using Kerberos in the same domain.
+ [ActiveIssue("27824")] // When specifying instance name and port number, this method call always returns false
+ [ConditionalFact(nameof(IsKerberos))]
+ public static void PortNumberInSPNTest()
+ {
+ string connStr = DataTestUtility.TCPConnectionString;
+ // If config.json.SupportsIntegratedSecurity = true, replace all keys defined below with Integrated Security=true
+ if (DataTestUtility.IsIntegratedSecuritySetup())
+ {
+ string[] removeKeys = { "Authentication", "User ID", "Password", "UID", "PWD", "Trusted_Connection" };
+ connStr = DataTestUtility.RemoveKeysInConnStr(DataTestUtility.TCPConnectionString, removeKeys) + $"Integrated Security=true";
+ }
+
+ SqlConnectionStringBuilder builder = new(connStr);
+
+ Assert.True(DataTestUtility.ParseDataSource(builder.DataSource, out string hostname, out _, out string instanceName), "Data source to be parsed must contain a host name and instance name");
+
+ bool condition = IsBrowserAlive(hostname) && IsValidInstance(hostname, instanceName);
+ Assert.True(condition, "Browser service is not running or instance name is invalid");
+
+ if (condition)
+ {
+ using SqlConnection connection = new(builder.ConnectionString);
+ connection.Open();
+ using SqlCommand command = new("SELECT auth_scheme, local_tcp_port from sys.dm_exec_connections where session_id = @@spid", connection);
+ using SqlDataReader reader = command.ExecuteReader();
+ Assert.True(reader.Read(), "Expected to receive one row data");
+ Assert.Equal("KERBEROS", reader.GetString(0));
+ int localTcpPort = reader.GetInt32(1);
+
+ int spnPort = -1;
+ string spnInfo = GetSPNInfo(builder.DataSource, out spnPort);
+
+ // sample output to validate = MSSQLSvc/machine.domain.tld:spnPort"
+ Assert.Contains($"MSSQLSvc/{hostname}", spnInfo);
+ // the local_tcp_port should be the same as the inferred SPN port from instance name
+ Assert.Equal(localTcpPort, spnPort);
+ }
+ }
+
+ private static string GetSPNInfo(string datasource, out int out_port)
+ {
+ Assembly sqlConnectionAssembly = Assembly.GetAssembly(typeof(SqlConnection));
+
+ // Get all required types using reflection
+ Type sniProxyType = sqlConnectionAssembly.GetType("Microsoft.Data.SqlClient.SNI.SNIProxy");
+ Type ssrpType = sqlConnectionAssembly.GetType("Microsoft.Data.SqlClient.SNI.SSRP");
+ Type dataSourceType = sqlConnectionAssembly.GetType("Microsoft.Data.SqlClient.SNI.DataSource");
+ Type timeoutTimerType = sqlConnectionAssembly.GetType("Microsoft.Data.ProviderBase.TimeoutTimer");
+
+ // Used in Datasource constructor param type array
+ Type[] dataSourceConstructorTypesArray = new Type[] { typeof(string) };
+
+ // Used in GetSqlServerSPNs function param types array
+ Type[] getSqlServerSPNsTypesArray = new Type[] { dataSourceType, typeof(string) };
+
+ // GetPortByInstanceName parameters array
+ Type[] getPortByInstanceNameTypesArray = new Type[] { typeof(string), typeof(string), timeoutTimerType, typeof(bool), typeof(Microsoft.Data.SqlClient.SqlConnectionIPAddressPreference) };
+
+ // TimeoutTimer.StartSecondsTimeout params
+ Type[] startSecondsTimeoutTypesArray = new Type[] { typeof(int) };
+
+ // Get all types constructors
+ ConstructorInfo sniProxyCtor = sniProxyType.GetConstructor(BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic, null, CallingConventions.Any, Type.EmptyTypes, null);
+ ConstructorInfo SSRPCtor = ssrpType.GetConstructor(BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic, null, CallingConventions.Any, Type.EmptyTypes, null);
+ ConstructorInfo dataSourceCtor = dataSourceType.GetConstructor(BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic, null, CallingConventions.Any, dataSourceConstructorTypesArray, null);
+ ConstructorInfo timeoutTimerCtor = timeoutTimerType.GetConstructor(BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic, null, CallingConventions.Any, Type.EmptyTypes, null);
+
+ // Instantiate SNIProxy
+ object sniProxy = sniProxyCtor.Invoke(new object[] { });
+
+ // Instantiate datasource
+ object dataSourceObj = dataSourceCtor.Invoke(new object[] { datasource });
+
+ // Instantiate SSRP
+ object ssrp = SSRPCtor.Invoke(new object[] { });
+
+ // Instantiate TimeoutTimer
+ object timeoutTimer = timeoutTimerCtor.Invoke(new object[] { });
+
+ // Get TimeoutTimer.StartSecondsTimeout Method
+ MethodInfo startSecondsTimeout = timeoutTimer.GetType().GetMethod("StartSecondsTimeout", BindingFlags.Static | BindingFlags.Public | BindingFlags.NonPublic, null, CallingConventions.Any, startSecondsTimeoutTypesArray, null);
+ // Create a timeoutTimer that expires in 30 seconds
+ timeoutTimer = startSecondsTimeout.Invoke(dataSourceObj, new object[] { 30 });
+
+ // Parse the datasource to separate the server name and instance name
+ MethodInfo ParseServerName = dataSourceObj.GetType().GetMethod("ParseServerName", BindingFlags.Static | BindingFlags.Public | BindingFlags.NonPublic, null, CallingConventions.Any, dataSourceConstructorTypesArray, null);
+ object dataSrcInfo = ParseServerName.Invoke(dataSourceObj, new object[] { datasource });
+
+ // Get the GetPortByInstanceName method of SSRP
+ MethodInfo getPortByInstanceName = ssrp.GetType().GetMethod("GetPortByInstanceName", BindingFlags.Static | BindingFlags.Public | BindingFlags.NonPublic, null, CallingConventions.Any, getPortByInstanceNameTypesArray, null);
+
+ // Get the server name
+ PropertyInfo serverInfo = dataSrcInfo.GetType().GetProperty("ServerName", BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic);
+ string serverName = serverInfo.GetValue(dataSrcInfo, null).ToString();
+
+ // Get the instance name
+ PropertyInfo instanceNameInfo = dataSrcInfo.GetType().GetProperty("InstanceName", BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic);
+ string instanceName = instanceNameInfo.GetValue(dataSrcInfo, null).ToString();
+
+ // Get the port number using the GetPortByInstanceName method of SSRP
+ object port = getPortByInstanceName.Invoke(ssrp, parameters: new object[] { serverName, instanceName, timeoutTimer, false, 0 });
+
+ // Set the resolved port property of datasource
+ PropertyInfo resolvedPortInfo = dataSrcInfo.GetType().GetProperty("ResolvedPort", BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic);
+ resolvedPortInfo.SetValue(dataSrcInfo, (int)port, null);
+
+ // Prepare the GetSqlServerSPNs method
+ string serverSPN = "";
+ MethodInfo getSqlServerSPNs = sniProxy.GetType().GetMethod("GetSqlServerSPNs", BindingFlags.Static | BindingFlags.Public | BindingFlags.NonPublic, null, CallingConventions.Any, getSqlServerSPNsTypesArray, null);
+
+ // Finally call GetSqlServerSPNs
+ byte[][] result = (byte[][])getSqlServerSPNs.Invoke(sniProxy, new object[] { dataSrcInfo, serverSPN });
+
+ // Example result: MSSQLSvc/machine.domain.tld:port"
+ string spnInfo = Encoding.Unicode.GetString(result[0]);
+
+ out_port = (int)port;
+
+ return spnInfo;
+ }
+
+ private static bool IsKerberos()
+ {
+ return (DataTestUtility.AreConnStringsSetup()
+ && DataTestUtility.IsNotLocalhost()
+ && DataTestUtility.IsKerberosTest
+ && DataTestUtility.IsNotAzureServer()
+ && DataTestUtility.IsNotAzureSynapse());
+ }
+
private static bool IsBrowserAlive(string browserHostname)
{
const byte ClntUcastEx = 0x03;