From 94c089e75842909efb2226edc248f41c3776a037 Mon Sep 17 00:00:00 2001 From: Aris Rellegue <134557572+arellegue@users.noreply.github.com> Date: Mon, 26 Feb 2024 14:00:09 -0800 Subject: [PATCH] Fix | Fix unit test for SPN to include port number with Managed SNI (#2281) --- .../Microsoft/Data/SqlClient/SNI/SNIProxy.cs | 6 +- .../ManualTests/DataCommon/DataTestUtility.cs | 6 +- .../SQL/InstanceNameTest/InstanceNameTest.cs | 180 +++++++++++------- 3 files changed, 113 insertions(+), 79 deletions(-) diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIProxy.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIProxy.cs index 3df369a2f6..d39f382bd6 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIProxy.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIProxy.cs @@ -653,10 +653,10 @@ private bool InferConnectionDetails() Port = port; } - // Instance Name Handling. Only if we found a '\' and we did not find a port in the Data Source - else if (backSlashIndex > -1) + // Instance Name Handling. + if (backSlashIndex > -1) { - // This means that there will not be any part separated by comma. + // This means that there is a part separated by '\' InstanceName = tokensByCommaAndSlash[1].Trim(); if (string.IsNullOrWhiteSpace(InstanceName)) diff --git a/src/Microsoft.Data.SqlClient/tests/ManualTests/DataCommon/DataTestUtility.cs b/src/Microsoft.Data.SqlClient/tests/ManualTests/DataCommon/DataTestUtility.cs index 5d6f9205b1..d588761bdb 100644 --- a/src/Microsoft.Data.SqlClient/tests/ManualTests/DataCommon/DataTestUtility.cs +++ b/src/Microsoft.Data.SqlClient/tests/ManualTests/DataCommon/DataTestUtility.cs @@ -1004,9 +1004,6 @@ public static bool ParseDataSource(string dataSource, out string hostname, out i port = -1; instanceName = string.Empty; - if (dataSource.Contains(",") && dataSource.Contains("\\")) - return false; - if (dataSource.Contains(":")) { dataSource = dataSource.Substring(dataSource.IndexOf(":", StringComparison.Ordinal) + 1); @@ -1018,7 +1015,8 @@ public static bool ParseDataSource(string dataSource, out string hostname, out i { return false; } - dataSource = dataSource.Substring(0, dataSource.IndexOf(",", StringComparison.Ordinal) - 1); + // IndexOf is zero-based, no need to subtract one + dataSource = dataSource.Substring(0, dataSource.IndexOf(",", StringComparison.Ordinal)); } if (dataSource.Contains("\\")) diff --git a/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/InstanceNameTest/InstanceNameTest.cs b/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/InstanceNameTest/InstanceNameTest.cs index 087b44d964..205b9d33f1 100644 --- a/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/InstanceNameTest/InstanceNameTest.cs +++ b/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/InstanceNameTest/InstanceNameTest.cs @@ -14,6 +14,8 @@ namespace Microsoft.Data.SqlClient.ManualTesting.Tests { public static class InstanceNameTest { + private const char SemicolonSeparator = ';'; + [ConditionalFact(typeof(DataTestUtility), nameof(DataTestUtility.IsNotAzureServer), nameof(DataTestUtility.IsNotAzureSynapse), nameof(DataTestUtility.AreConnStringsSetup))] public static void ConnectToSQLWithInstanceNameTest() { @@ -84,138 +86,135 @@ public static void ConnectManagedWithInstanceNameTest(bool useMultiSubnetFailove } } - // Note: This Unit test was tested in a domain-joined VM connecting to a remote - // SQL Server using Kerberos in the same domain. - [ActiveIssue("27824")] // When specifying instance name and port number, this method call always returns false - [ConditionalFact(nameof(IsKerberos))] - public static void PortNumberInSPNTest() +#if NETCOREAPP + [ConditionalFact(nameof(IsSPNPortNumberTestForTCP))] + public static void PortNumberInSPNTestForTCP() + { + string connectionString = DataTestUtility.TCPConnectionString; + SqlConnectionStringBuilder builder = new(connectionString); + + int port = GetNamedInstancePortNumberFromSqlBrowser(connectionString); + Assert.True(port > 0, "Named instance must have a valid port number."); + builder.DataSource = $"{builder.DataSource},{port}"; + + PortNumberInSPNTest(builder.ConnectionString, port); + } +#endif + + private static void PortNumberInSPNTest(string connectionString, int expectedPortNumber) { - string connStr = DataTestUtility.TCPConnectionString; - // If config.json.SupportsIntegratedSecurity = true, replace all keys defined below with Integrated Security=true if (DataTestUtility.IsIntegratedSecuritySetup()) { string[] removeKeys = { "Authentication", "User ID", "Password", "UID", "PWD", "Trusted_Connection" }; - connStr = DataTestUtility.RemoveKeysInConnStr(DataTestUtility.TCPConnectionString, removeKeys) + $"Integrated Security=true"; + connectionString = DataTestUtility.RemoveKeysInConnStr(connectionString, removeKeys) + $"Integrated Security=true"; } - SqlConnectionStringBuilder builder = new(connStr); + SqlConnectionStringBuilder builder = new(connectionString); + + string hostname = ""; + string instanceName = ""; - Assert.True(DataTestUtility.ParseDataSource(builder.DataSource, out string hostname, out _, out string instanceName), "Data source to be parsed must contain a host name and instance name"); + DataTestUtility.ParseDataSource(builder.DataSource, out hostname, out _, out instanceName); - bool condition = IsBrowserAlive(hostname) && IsValidInstance(hostname, instanceName); - Assert.True(condition, "Browser service is not running or instance name is invalid"); + Assert.False(string.IsNullOrEmpty(hostname), "Hostname must be included in the data source."); + Assert.False(string.IsNullOrEmpty(instanceName), "Instance name must be included in the data source."); - if (condition) + using (SqlConnection connection = new(builder.ConnectionString)) { - using SqlConnection connection = new(builder.ConnectionString); connection.Open(); - using SqlCommand command = new("SELECT auth_scheme, local_tcp_port from sys.dm_exec_connections where session_id = @@spid", connection); - using SqlDataReader reader = command.ExecuteReader(); - Assert.True(reader.Read(), "Expected to receive one row data"); - Assert.Equal("KERBEROS", reader.GetString(0)); - int localTcpPort = reader.GetInt32(1); - - int spnPort = -1; - string spnInfo = GetSPNInfo(builder.DataSource, out spnPort); - - // sample output to validate = MSSQLSvc/machine.domain.tld:spnPort" - Assert.Contains($"MSSQLSvc/{hostname}", spnInfo); - // the local_tcp_port should be the same as the inferred SPN port from instance name - Assert.Equal(localTcpPort, spnPort); + + string spnInfo = GetSPNInfo(builder.DataSource); + Assert.Matches(@"MSSQLSvc\/.*:[\d]", spnInfo); + + string[] spnStrs = spnInfo.Split(':'); + int portInSPN = 0; + if (spnStrs.Length > 1) + { + int.TryParse(spnStrs[1], out portInSPN); + } + Assert.Equal(expectedPortNumber, portInSPN); } } - private static string GetSPNInfo(string datasource, out int out_port) + private static string GetSPNInfo(string dataSource) { Assembly sqlConnectionAssembly = Assembly.GetAssembly(typeof(SqlConnection)); - // Get all required types using reflection Type sniProxyType = sqlConnectionAssembly.GetType("Microsoft.Data.SqlClient.SNI.SNIProxy"); Type ssrpType = sqlConnectionAssembly.GetType("Microsoft.Data.SqlClient.SNI.SSRP"); Type dataSourceType = sqlConnectionAssembly.GetType("Microsoft.Data.SqlClient.SNI.DataSource"); Type timeoutTimerType = sqlConnectionAssembly.GetType("Microsoft.Data.ProviderBase.TimeoutTimer"); - // Used in Datasource constructor param type array Type[] dataSourceConstructorTypesArray = new Type[] { typeof(string) }; - // Used in GetSqlServerSPNs function param types array Type[] getSqlServerSPNsTypesArray = new Type[] { dataSourceType, typeof(string) }; - // GetPortByInstanceName parameters array Type[] getPortByInstanceNameTypesArray = new Type[] { typeof(string), typeof(string), timeoutTimerType, typeof(bool), typeof(Microsoft.Data.SqlClient.SqlConnectionIPAddressPreference) }; - // TimeoutTimer.StartSecondsTimeout params Type[] startSecondsTimeoutTypesArray = new Type[] { typeof(int) }; - // Get all types constructors - ConstructorInfo sniProxyCtor = sniProxyType.GetConstructor(BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic, null, CallingConventions.Any, Type.EmptyTypes, null); - ConstructorInfo SSRPCtor = ssrpType.GetConstructor(BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic, null, CallingConventions.Any, Type.EmptyTypes, null); - ConstructorInfo dataSourceCtor = dataSourceType.GetConstructor(BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic, null, CallingConventions.Any, dataSourceConstructorTypesArray, null); - ConstructorInfo timeoutTimerCtor = timeoutTimerType.GetConstructor(BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic, null, CallingConventions.Any, Type.EmptyTypes, null); + ConstructorInfo sniProxyConstructor = sniProxyType.GetConstructor(BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic, null, CallingConventions.Any, Type.EmptyTypes, null); + ConstructorInfo SSRPConstructor = ssrpType.GetConstructor(BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic, null, CallingConventions.Any, Type.EmptyTypes, null); + ConstructorInfo dataSourceConstructor = dataSourceType.GetConstructor(BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic, null, CallingConventions.Any, dataSourceConstructorTypesArray, null); + ConstructorInfo timeoutTimerConstructor = timeoutTimerType.GetConstructor(BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic, null, CallingConventions.Any, Type.EmptyTypes, null); - // Instantiate SNIProxy - object sniProxy = sniProxyCtor.Invoke(new object[] { }); + object sniProxyObj = sniProxyConstructor.Invoke(new object[] { }); - // Instantiate datasource - object dataSourceObj = dataSourceCtor.Invoke(new object[] { datasource }); + object dataSourceObj = dataSourceConstructor.Invoke(new object[] { dataSource }); - // Instantiate SSRP - object ssrp = SSRPCtor.Invoke(new object[] { }); + object ssrpObj = SSRPConstructor.Invoke(new object[] { }); - // Instantiate TimeoutTimer - object timeoutTimer = timeoutTimerCtor.Invoke(new object[] { }); + object timeoutTimerObj = timeoutTimerConstructor.Invoke(new object[] { }); - // Get TimeoutTimer.StartSecondsTimeout Method - MethodInfo startSecondsTimeout = timeoutTimer.GetType().GetMethod("StartSecondsTimeout", BindingFlags.Static | BindingFlags.Public | BindingFlags.NonPublic, null, CallingConventions.Any, startSecondsTimeoutTypesArray, null); - // Create a timeoutTimer that expires in 30 seconds - timeoutTimer = startSecondsTimeout.Invoke(dataSourceObj, new object[] { 30 }); + MethodInfo startSecondsTimeoutInfo = timeoutTimerObj.GetType().GetMethod("StartSecondsTimeout", BindingFlags.Static | BindingFlags.Public | BindingFlags.NonPublic, null, CallingConventions.Any, startSecondsTimeoutTypesArray, null); - // Parse the datasource to separate the server name and instance name - MethodInfo ParseServerName = dataSourceObj.GetType().GetMethod("ParseServerName", BindingFlags.Static | BindingFlags.Public | BindingFlags.NonPublic, null, CallingConventions.Any, dataSourceConstructorTypesArray, null); - object dataSrcInfo = ParseServerName.Invoke(dataSourceObj, new object[] { datasource }); + timeoutTimerObj = startSecondsTimeoutInfo.Invoke(dataSourceObj, new object[] { 30 }); - // Get the GetPortByInstanceName method of SSRP - MethodInfo getPortByInstanceName = ssrp.GetType().GetMethod("GetPortByInstanceName", BindingFlags.Static | BindingFlags.Public | BindingFlags.NonPublic, null, CallingConventions.Any, getPortByInstanceNameTypesArray, null); + MethodInfo parseServerNameInfo = dataSourceObj.GetType().GetMethod("ParseServerName", BindingFlags.Static | BindingFlags.Public | BindingFlags.NonPublic, null, CallingConventions.Any, dataSourceConstructorTypesArray, null); + object dataSrcInfo = parseServerNameInfo.Invoke(dataSourceObj, new object[] { dataSource }); + + MethodInfo getPortByInstanceNameInfo = ssrpObj.GetType().GetMethod("GetPortByInstanceName", BindingFlags.Static | BindingFlags.Public | BindingFlags.NonPublic, null, CallingConventions.Any, getPortByInstanceNameTypesArray, null); - // Get the server name PropertyInfo serverInfo = dataSrcInfo.GetType().GetProperty("ServerName", BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic); string serverName = serverInfo.GetValue(dataSrcInfo, null).ToString(); - // Get the instance name PropertyInfo instanceNameInfo = dataSrcInfo.GetType().GetProperty("InstanceName", BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic); string instanceName = instanceNameInfo.GetValue(dataSrcInfo, null).ToString(); - // Get the port number using the GetPortByInstanceName method of SSRP - object port = getPortByInstanceName.Invoke(ssrp, parameters: new object[] { serverName, instanceName, timeoutTimer, false, 0 }); + object port = getPortByInstanceNameInfo.Invoke(ssrpObj, parameters: new object[] { serverName, instanceName, timeoutTimerObj, false, 0 }); - // Set the resolved port property of datasource PropertyInfo resolvedPortInfo = dataSrcInfo.GetType().GetProperty("ResolvedPort", BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic); resolvedPortInfo.SetValue(dataSrcInfo, (int)port, null); - // Prepare the GetSqlServerSPNs method string serverSPN = ""; - MethodInfo getSqlServerSPNs = sniProxy.GetType().GetMethod("GetSqlServerSPNs", BindingFlags.Static | BindingFlags.Public | BindingFlags.NonPublic, null, CallingConventions.Any, getSqlServerSPNsTypesArray, null); + MethodInfo getSqlServerSPNs = sniProxyObj.GetType().GetMethod("GetSqlServerSPNs", BindingFlags.Static | BindingFlags.Public | BindingFlags.NonPublic, null, CallingConventions.Any, getSqlServerSPNsTypesArray, null); - // Finally call GetSqlServerSPNs - byte[][] result = (byte[][])getSqlServerSPNs.Invoke(sniProxy, new object[] { dataSrcInfo, serverSPN }); + byte[][] result = (byte[][])getSqlServerSPNs.Invoke(sniProxyObj, new object[] { dataSrcInfo, serverSPN }); - // Example result: MSSQLSvc/machine.domain.tld:port" string spnInfo = Encoding.Unicode.GetString(result[0]); - out_port = (int)port; - return spnInfo; } - private static bool IsKerberos() + private static bool IsSPNPortNumberTestForTCP() { - return (DataTestUtility.AreConnStringsSetup() - && DataTestUtility.IsNotLocalhost() - && DataTestUtility.IsKerberosTest - && DataTestUtility.IsNotAzureServer() + return (IsInstanceNameValid(DataTestUtility.TCPConnectionString) + && DataTestUtility.IsUsingManagedSNI() + && DataTestUtility.IsNotAzureServer() && DataTestUtility.IsNotAzureSynapse()); } + private static bool IsInstanceNameValid(string connectionString) + { + string instanceName = ""; + + SqlConnectionStringBuilder builder = new(connectionString); + + bool isDataSourceValid = DataTestUtility.ParseDataSource(builder.DataSource, out _, out _, out instanceName); + + return isDataSourceValid && !string.IsNullOrWhiteSpace(instanceName); + } + private static bool IsBrowserAlive(string browserHostname) { const byte ClntUcastEx = 0x03; @@ -231,6 +230,43 @@ private static bool IsValidInstance(string browserHostName, string instanceName) return response != null && response.Length > 0; } + private static int GetNamedInstancePortNumberFromSqlBrowser(string connectionString) + { + SqlConnectionStringBuilder builder = new(connectionString); + + string hostname = ""; + string instanceName = ""; + int port = 0; + + bool isDataSourceValid = DataTestUtility.ParseDataSource(builder.DataSource, out hostname, out _, out instanceName); + Assert.True(isDataSourceValid, "DataSource is invalid"); + + bool isBrowserRunning = IsBrowserAlive(hostname); + Assert.True(isBrowserRunning, "Browser service is not running."); + + bool isInstanceExisting = IsValidInstance(hostname, instanceName); + Assert.True(isInstanceExisting, "Instance name is invalid."); + + if (isDataSourceValid && isBrowserRunning && isInstanceExisting) + { + byte[] request = CreateInstanceInfoRequest(instanceName); + byte[] response = QueryBrowser(hostname, request); + + string serverMessage = Encoding.ASCII.GetString(response, 3, response.Length - 3); + + string[] elements = serverMessage.Split(SemicolonSeparator); + int tcpIndex = Array.IndexOf(elements, "tcp"); + if (tcpIndex < 0 || tcpIndex == elements.Length - 1) + { + throw new SocketException(); + } + + port = (int)ushort.Parse(elements[tcpIndex + 1]); + } + + return port; + } + private static byte[] QueryBrowser(string browserHostname, byte[] requestPacket) { const int DefaultBrowserPort = 1434;