Skip to content

Commit

Permalink
Fix | Fix unit test for SPN to include port number with Managed SNI (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
arellegue authored Feb 26, 2024
1 parent 5cd9514 commit 94c089e
Show file tree
Hide file tree
Showing 3 changed files with 113 additions and 79 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -653,10 +653,10 @@ private bool InferConnectionDetails()

Port = port;
}
// Instance Name Handling. Only if we found a '\' and we did not find a port in the Data Source
else if (backSlashIndex > -1)
// Instance Name Handling.
if (backSlashIndex > -1)
{
// This means that there will not be any part separated by comma.
// This means that there is a part separated by '\'
InstanceName = tokensByCommaAndSlash[1].Trim();

if (string.IsNullOrWhiteSpace(InstanceName))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1004,9 +1004,6 @@ public static bool ParseDataSource(string dataSource, out string hostname, out i
port = -1;
instanceName = string.Empty;

if (dataSource.Contains(",") && dataSource.Contains("\\"))
return false;

if (dataSource.Contains(":"))
{
dataSource = dataSource.Substring(dataSource.IndexOf(":", StringComparison.Ordinal) + 1);
Expand All @@ -1018,7 +1015,8 @@ public static bool ParseDataSource(string dataSource, out string hostname, out i
{
return false;
}
dataSource = dataSource.Substring(0, dataSource.IndexOf(",", StringComparison.Ordinal) - 1);
// IndexOf is zero-based, no need to subtract one
dataSource = dataSource.Substring(0, dataSource.IndexOf(",", StringComparison.Ordinal));
}

if (dataSource.Contains("\\"))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ namespace Microsoft.Data.SqlClient.ManualTesting.Tests
{
public static class InstanceNameTest
{
private const char SemicolonSeparator = ';';

[ConditionalFact(typeof(DataTestUtility), nameof(DataTestUtility.IsNotAzureServer), nameof(DataTestUtility.IsNotAzureSynapse), nameof(DataTestUtility.AreConnStringsSetup))]
public static void ConnectToSQLWithInstanceNameTest()
{
Expand Down Expand Up @@ -84,138 +86,135 @@ public static void ConnectManagedWithInstanceNameTest(bool useMultiSubnetFailove
}
}

// Note: This Unit test was tested in a domain-joined VM connecting to a remote
// SQL Server using Kerberos in the same domain.
[ActiveIssue("27824")] // When specifying instance name and port number, this method call always returns false
[ConditionalFact(nameof(IsKerberos))]
public static void PortNumberInSPNTest()
#if NETCOREAPP
[ConditionalFact(nameof(IsSPNPortNumberTestForTCP))]
public static void PortNumberInSPNTestForTCP()
{
string connectionString = DataTestUtility.TCPConnectionString;
SqlConnectionStringBuilder builder = new(connectionString);

int port = GetNamedInstancePortNumberFromSqlBrowser(connectionString);
Assert.True(port > 0, "Named instance must have a valid port number.");
builder.DataSource = $"{builder.DataSource},{port}";

PortNumberInSPNTest(builder.ConnectionString, port);
}
#endif

private static void PortNumberInSPNTest(string connectionString, int expectedPortNumber)
{
string connStr = DataTestUtility.TCPConnectionString;
// If config.json.SupportsIntegratedSecurity = true, replace all keys defined below with Integrated Security=true
if (DataTestUtility.IsIntegratedSecuritySetup())
{
string[] removeKeys = { "Authentication", "User ID", "Password", "UID", "PWD", "Trusted_Connection" };
connStr = DataTestUtility.RemoveKeysInConnStr(DataTestUtility.TCPConnectionString, removeKeys) + $"Integrated Security=true";
connectionString = DataTestUtility.RemoveKeysInConnStr(connectionString, removeKeys) + $"Integrated Security=true";
}

SqlConnectionStringBuilder builder = new(connStr);
SqlConnectionStringBuilder builder = new(connectionString);

string hostname = "";
string instanceName = "";

Assert.True(DataTestUtility.ParseDataSource(builder.DataSource, out string hostname, out _, out string instanceName), "Data source to be parsed must contain a host name and instance name");
DataTestUtility.ParseDataSource(builder.DataSource, out hostname, out _, out instanceName);

bool condition = IsBrowserAlive(hostname) && IsValidInstance(hostname, instanceName);
Assert.True(condition, "Browser service is not running or instance name is invalid");
Assert.False(string.IsNullOrEmpty(hostname), "Hostname must be included in the data source.");
Assert.False(string.IsNullOrEmpty(instanceName), "Instance name must be included in the data source.");

if (condition)
using (SqlConnection connection = new(builder.ConnectionString))
{
using SqlConnection connection = new(builder.ConnectionString);
connection.Open();
using SqlCommand command = new("SELECT auth_scheme, local_tcp_port from sys.dm_exec_connections where session_id = @@spid", connection);
using SqlDataReader reader = command.ExecuteReader();
Assert.True(reader.Read(), "Expected to receive one row data");
Assert.Equal("KERBEROS", reader.GetString(0));
int localTcpPort = reader.GetInt32(1);

int spnPort = -1;
string spnInfo = GetSPNInfo(builder.DataSource, out spnPort);

// sample output to validate = MSSQLSvc/machine.domain.tld:spnPort"
Assert.Contains($"MSSQLSvc/{hostname}", spnInfo);
// the local_tcp_port should be the same as the inferred SPN port from instance name
Assert.Equal(localTcpPort, spnPort);

string spnInfo = GetSPNInfo(builder.DataSource);
Assert.Matches(@"MSSQLSvc\/.*:[\d]", spnInfo);

string[] spnStrs = spnInfo.Split(':');
int portInSPN = 0;
if (spnStrs.Length > 1)
{
int.TryParse(spnStrs[1], out portInSPN);
}
Assert.Equal(expectedPortNumber, portInSPN);
}
}

private static string GetSPNInfo(string datasource, out int out_port)
private static string GetSPNInfo(string dataSource)
{
Assembly sqlConnectionAssembly = Assembly.GetAssembly(typeof(SqlConnection));

// Get all required types using reflection
Type sniProxyType = sqlConnectionAssembly.GetType("Microsoft.Data.SqlClient.SNI.SNIProxy");
Type ssrpType = sqlConnectionAssembly.GetType("Microsoft.Data.SqlClient.SNI.SSRP");
Type dataSourceType = sqlConnectionAssembly.GetType("Microsoft.Data.SqlClient.SNI.DataSource");
Type timeoutTimerType = sqlConnectionAssembly.GetType("Microsoft.Data.ProviderBase.TimeoutTimer");

// Used in Datasource constructor param type array
Type[] dataSourceConstructorTypesArray = new Type[] { typeof(string) };

// Used in GetSqlServerSPNs function param types array
Type[] getSqlServerSPNsTypesArray = new Type[] { dataSourceType, typeof(string) };

// GetPortByInstanceName parameters array
Type[] getPortByInstanceNameTypesArray = new Type[] { typeof(string), typeof(string), timeoutTimerType, typeof(bool), typeof(Microsoft.Data.SqlClient.SqlConnectionIPAddressPreference) };

// TimeoutTimer.StartSecondsTimeout params
Type[] startSecondsTimeoutTypesArray = new Type[] { typeof(int) };

// Get all types constructors
ConstructorInfo sniProxyCtor = sniProxyType.GetConstructor(BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic, null, CallingConventions.Any, Type.EmptyTypes, null);
ConstructorInfo SSRPCtor = ssrpType.GetConstructor(BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic, null, CallingConventions.Any, Type.EmptyTypes, null);
ConstructorInfo dataSourceCtor = dataSourceType.GetConstructor(BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic, null, CallingConventions.Any, dataSourceConstructorTypesArray, null);
ConstructorInfo timeoutTimerCtor = timeoutTimerType.GetConstructor(BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic, null, CallingConventions.Any, Type.EmptyTypes, null);
ConstructorInfo sniProxyConstructor = sniProxyType.GetConstructor(BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic, null, CallingConventions.Any, Type.EmptyTypes, null);
ConstructorInfo SSRPConstructor = ssrpType.GetConstructor(BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic, null, CallingConventions.Any, Type.EmptyTypes, null);
ConstructorInfo dataSourceConstructor = dataSourceType.GetConstructor(BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic, null, CallingConventions.Any, dataSourceConstructorTypesArray, null);
ConstructorInfo timeoutTimerConstructor = timeoutTimerType.GetConstructor(BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic, null, CallingConventions.Any, Type.EmptyTypes, null);

// Instantiate SNIProxy
object sniProxy = sniProxyCtor.Invoke(new object[] { });
object sniProxyObj = sniProxyConstructor.Invoke(new object[] { });

// Instantiate datasource
object dataSourceObj = dataSourceCtor.Invoke(new object[] { datasource });
object dataSourceObj = dataSourceConstructor.Invoke(new object[] { dataSource });

// Instantiate SSRP
object ssrp = SSRPCtor.Invoke(new object[] { });
object ssrpObj = SSRPConstructor.Invoke(new object[] { });

// Instantiate TimeoutTimer
object timeoutTimer = timeoutTimerCtor.Invoke(new object[] { });
object timeoutTimerObj = timeoutTimerConstructor.Invoke(new object[] { });

// Get TimeoutTimer.StartSecondsTimeout Method
MethodInfo startSecondsTimeout = timeoutTimer.GetType().GetMethod("StartSecondsTimeout", BindingFlags.Static | BindingFlags.Public | BindingFlags.NonPublic, null, CallingConventions.Any, startSecondsTimeoutTypesArray, null);
// Create a timeoutTimer that expires in 30 seconds
timeoutTimer = startSecondsTimeout.Invoke(dataSourceObj, new object[] { 30 });
MethodInfo startSecondsTimeoutInfo = timeoutTimerObj.GetType().GetMethod("StartSecondsTimeout", BindingFlags.Static | BindingFlags.Public | BindingFlags.NonPublic, null, CallingConventions.Any, startSecondsTimeoutTypesArray, null);

// Parse the datasource to separate the server name and instance name
MethodInfo ParseServerName = dataSourceObj.GetType().GetMethod("ParseServerName", BindingFlags.Static | BindingFlags.Public | BindingFlags.NonPublic, null, CallingConventions.Any, dataSourceConstructorTypesArray, null);
object dataSrcInfo = ParseServerName.Invoke(dataSourceObj, new object[] { datasource });
timeoutTimerObj = startSecondsTimeoutInfo.Invoke(dataSourceObj, new object[] { 30 });

// Get the GetPortByInstanceName method of SSRP
MethodInfo getPortByInstanceName = ssrp.GetType().GetMethod("GetPortByInstanceName", BindingFlags.Static | BindingFlags.Public | BindingFlags.NonPublic, null, CallingConventions.Any, getPortByInstanceNameTypesArray, null);
MethodInfo parseServerNameInfo = dataSourceObj.GetType().GetMethod("ParseServerName", BindingFlags.Static | BindingFlags.Public | BindingFlags.NonPublic, null, CallingConventions.Any, dataSourceConstructorTypesArray, null);
object dataSrcInfo = parseServerNameInfo.Invoke(dataSourceObj, new object[] { dataSource });

MethodInfo getPortByInstanceNameInfo = ssrpObj.GetType().GetMethod("GetPortByInstanceName", BindingFlags.Static | BindingFlags.Public | BindingFlags.NonPublic, null, CallingConventions.Any, getPortByInstanceNameTypesArray, null);

// Get the server name
PropertyInfo serverInfo = dataSrcInfo.GetType().GetProperty("ServerName", BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic);
string serverName = serverInfo.GetValue(dataSrcInfo, null).ToString();

// Get the instance name
PropertyInfo instanceNameInfo = dataSrcInfo.GetType().GetProperty("InstanceName", BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic);
string instanceName = instanceNameInfo.GetValue(dataSrcInfo, null).ToString();

// Get the port number using the GetPortByInstanceName method of SSRP
object port = getPortByInstanceName.Invoke(ssrp, parameters: new object[] { serverName, instanceName, timeoutTimer, false, 0 });
object port = getPortByInstanceNameInfo.Invoke(ssrpObj, parameters: new object[] { serverName, instanceName, timeoutTimerObj, false, 0 });

// Set the resolved port property of datasource
PropertyInfo resolvedPortInfo = dataSrcInfo.GetType().GetProperty("ResolvedPort", BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic);
resolvedPortInfo.SetValue(dataSrcInfo, (int)port, null);

// Prepare the GetSqlServerSPNs method
string serverSPN = "";
MethodInfo getSqlServerSPNs = sniProxy.GetType().GetMethod("GetSqlServerSPNs", BindingFlags.Static | BindingFlags.Public | BindingFlags.NonPublic, null, CallingConventions.Any, getSqlServerSPNsTypesArray, null);
MethodInfo getSqlServerSPNs = sniProxyObj.GetType().GetMethod("GetSqlServerSPNs", BindingFlags.Static | BindingFlags.Public | BindingFlags.NonPublic, null, CallingConventions.Any, getSqlServerSPNsTypesArray, null);

// Finally call GetSqlServerSPNs
byte[][] result = (byte[][])getSqlServerSPNs.Invoke(sniProxy, new object[] { dataSrcInfo, serverSPN });
byte[][] result = (byte[][])getSqlServerSPNs.Invoke(sniProxyObj, new object[] { dataSrcInfo, serverSPN });

// Example result: MSSQLSvc/machine.domain.tld:port"
string spnInfo = Encoding.Unicode.GetString(result[0]);

out_port = (int)port;

return spnInfo;
}

private static bool IsKerberos()
private static bool IsSPNPortNumberTestForTCP()
{
return (DataTestUtility.AreConnStringsSetup()
&& DataTestUtility.IsNotLocalhost()
&& DataTestUtility.IsKerberosTest
&& DataTestUtility.IsNotAzureServer()
return (IsInstanceNameValid(DataTestUtility.TCPConnectionString)
&& DataTestUtility.IsUsingManagedSNI()
&& DataTestUtility.IsNotAzureServer()
&& DataTestUtility.IsNotAzureSynapse());
}

private static bool IsInstanceNameValid(string connectionString)
{
string instanceName = "";

SqlConnectionStringBuilder builder = new(connectionString);

bool isDataSourceValid = DataTestUtility.ParseDataSource(builder.DataSource, out _, out _, out instanceName);

return isDataSourceValid && !string.IsNullOrWhiteSpace(instanceName);
}

private static bool IsBrowserAlive(string browserHostname)
{
const byte ClntUcastEx = 0x03;
Expand All @@ -231,6 +230,43 @@ private static bool IsValidInstance(string browserHostName, string instanceName)
return response != null && response.Length > 0;
}

private static int GetNamedInstancePortNumberFromSqlBrowser(string connectionString)
{
SqlConnectionStringBuilder builder = new(connectionString);

string hostname = "";
string instanceName = "";
int port = 0;

bool isDataSourceValid = DataTestUtility.ParseDataSource(builder.DataSource, out hostname, out _, out instanceName);
Assert.True(isDataSourceValid, "DataSource is invalid");

bool isBrowserRunning = IsBrowserAlive(hostname);
Assert.True(isBrowserRunning, "Browser service is not running.");

bool isInstanceExisting = IsValidInstance(hostname, instanceName);
Assert.True(isInstanceExisting, "Instance name is invalid.");

if (isDataSourceValid && isBrowserRunning && isInstanceExisting)
{
byte[] request = CreateInstanceInfoRequest(instanceName);
byte[] response = QueryBrowser(hostname, request);

string serverMessage = Encoding.ASCII.GetString(response, 3, response.Length - 3);

string[] elements = serverMessage.Split(SemicolonSeparator);
int tcpIndex = Array.IndexOf(elements, "tcp");
if (tcpIndex < 0 || tcpIndex == elements.Length - 1)
{
throw new SocketException();
}

port = (int)ushort.Parse(elements[tcpIndex + 1]);
}

return port;
}

private static byte[] QueryBrowser(string browserHostname, byte[] requestPacket)
{
const int DefaultBrowserPort = 1434;
Expand Down

0 comments on commit 94c089e

Please sign in to comment.