diff --git a/src/test/java/com/microsoft/sqlserver/jdbc/resiliency/BasicConnectionTest.java b/src/test/java/com/microsoft/sqlserver/jdbc/resiliency/BasicConnectionTest.java index 2b25c3de5..ef2210b8e 100644 --- a/src/test/java/com/microsoft/sqlserver/jdbc/resiliency/BasicConnectionTest.java +++ b/src/test/java/com/microsoft/sqlserver/jdbc/resiliency/BasicConnectionTest.java @@ -173,7 +173,7 @@ public void testOpenResultSets() throws SQLException { } } } - + @Test public void testPooledConnection() throws SQLException { SQLServerConnectionPoolDataSource mds = new SQLServerConnectionPoolDataSource(); @@ -191,7 +191,7 @@ public void testPooledConnection() throws SQLException { fail(e.getMessage()); } } - + @Test public void testPooledConnectionDB() throws SQLException { SQLServerConnectionPoolDataSource mds = new SQLServerConnectionPoolDataSource(); @@ -229,31 +229,33 @@ public void testPooledConnectionDB() throws SQLException { TestUtils.dropDatabaseIfExists(newDBName, connectionString); } } - + @Test public void testPooledConnectionLang() throws SQLException { SQLServerConnectionPoolDataSource mds = new SQLServerConnectionPoolDataSource(); mds.setURL(connectionString); PooledConnection pooledConnection = mds.getPooledConnection(); String lang0 = null, lang1 = null; - + try (Connection c = pooledConnection.getConnection(); Statement s = c.createStatement()) { ResiliencyUtils.minimizeIdleNetworkTrackerPooledConnection(c); ResultSet rs = s.executeQuery("SELECT @@LANGUAGE;"); while (rs.next()) - lang0 = rs.getString(1); + lang0 = rs.getString(1); s.execute("SET LANGUAGE FRENCH;"); c.close(); - Connection c1 = pooledConnection.getConnection(); - Statement s1 = c1.createStatement(); - ResiliencyUtils.killConnection(c1, connectionString); - ResiliencyUtils.minimizeIdleNetworkTrackerPooledConnection(c1); - rs = s1.executeQuery("SELECT @@LANGUAGE;"); - while (rs.next()) - lang1 = rs.getString(1); - assertEquals(lang0, lang1); - s1.close(); - c1.close(); + try (Connection c1 = pooledConnection.getConnection(); Statement s1 = c1.createStatement()) { + ResiliencyUtils.killConnection(c1, connectionString); + ResiliencyUtils.minimizeIdleNetworkTrackerPooledConnection(c1); + rs = s1.executeQuery("SELECT @@LANGUAGE;"); + while (rs.next()) + lang1 = rs.getString(1); + assertEquals(lang0, lang1); + s1.close(); + c1.close(); + } finally { + rs.close(); + } } catch (SQLException e) { e.printStackTrace(); fail(e.getMessage()); diff --git a/src/test/java/com/microsoft/sqlserver/jdbc/resiliency/ResiliencyUtils.java b/src/test/java/com/microsoft/sqlserver/jdbc/resiliency/ResiliencyUtils.java index a9952c65b..0e2e37e7c 100644 --- a/src/test/java/com/microsoft/sqlserver/jdbc/resiliency/ResiliencyUtils.java +++ b/src/test/java/com/microsoft/sqlserver/jdbc/resiliency/ResiliencyUtils.java @@ -21,6 +21,7 @@ import org.junit.Assert; +import com.microsoft.sqlserver.jdbc.SQLServerConnection; import com.microsoft.sqlserver.jdbc.SQLServerConnectionPoolDataSource; @@ -77,7 +78,6 @@ void init() { } }, QUOTED_IDENTIFIER(ON_OFF), - REMOTE_PROC_TRANSACTIONS(ON_OFF), ROWCOUNT(null) { @Override void init() { @@ -177,27 +177,27 @@ public void setTimer(int millis) { this.sleepTime = millis; } } - - public static Connection getPooledConnection(String connectionString) throws SQLException { + + protected static Connection getPooledConnection(String connectionString) throws SQLException { SQLServerConnectionPoolDataSource mds = new SQLServerConnectionPoolDataSource(); mds.setURL(connectionString); PooledConnection pooledConnection = mds.getPooledConnection(); Connection c = pooledConnection.getConnection(); - + minimizeIdleNetworkTrackerPooledConnection(c); return c; } - public static Connection getConnection(String connectionString) throws SQLException { + protected static Connection getConnection(String connectionString) throws SQLException { Connection c = DriverManager.getConnection(connectionString); minimizeIdleNetworkTracker(c); return c; } - - public static void minimizeIdleNetworkTrackerPooledConnection(Connection c) { + + protected static void minimizeIdleNetworkTrackerPooledConnection(Connection c) { try { Field fieldsProxy[] = c.getClass().getDeclaredFields(); - for (Field f : fieldsProxy ) { + for (Field f : fieldsProxy) { if (f.getName() == "wrappedConnection") { f.setAccessible(true); Object wrappedConnection = f.get(c); @@ -206,7 +206,8 @@ public static void minimizeIdleNetworkTrackerPooledConnection(Connection c) { if (ff.getName() == "idleNetworkTracker") { ff.setAccessible(true); Object idleNetworkTracker = ff.get(wrappedConnection); - Method method = idleNetworkTracker.getClass().getDeclaredMethod("setMaxIdleMillis", int.class); + Method method = idleNetworkTracker.getClass().getDeclaredMethod("setMaxIdleMillis", + int.class); method.setAccessible(true); method.invoke(idleNetworkTracker, -1); break; @@ -218,8 +219,8 @@ public static void minimizeIdleNetworkTrackerPooledConnection(Connection c) { Assert.fail("Failed to setMaxIdleMillis in Connection's idleNetworkTracker: " + e.getMessage()); } } - - public static void minimizeIdleNetworkTracker(Connection c) { + + protected static void minimizeIdleNetworkTracker(Connection c) { try { Field fields[] = c.getClass().getSuperclass().getDeclaredFields(); for (Field f : fields) { @@ -237,11 +238,11 @@ public static void minimizeIdleNetworkTracker(Connection c) { } } - public static void killConnection(Connection c, String cString) throws SQLException { + protected static void killConnection(Connection c, String cString) throws SQLException { killConnection(getSessionId(c), cString); } - public static int getSessionId(Connection c) throws SQLException { + protected static int getSessionId(Connection c) throws SQLException { int sessionID = 0; try (Statement s = c.createStatement()) { try (ResultSet rs = s.executeQuery("SELECT @@SPID")) { @@ -253,7 +254,7 @@ public static int getSessionId(Connection c) throws SQLException { return sessionID; } - public static void killConnection(int sessionID, String cString) throws SQLException { + protected static void killConnection(int sessionID, String cString) throws SQLException { try (Connection c2 = DriverManager.getConnection(cString)) { try (Statement s = c2.createStatement()) { s.execute("KILL " + sessionID); @@ -262,8 +263,14 @@ public static void killConnection(int sessionID, String cString) throws SQLExcep } // uses reflection to "corrupt" a Connection's server target - public static void blockConnection(Connection c) throws SQLException { - Field fields[] = c.getClass().getSuperclass().getDeclaredFields(); + protected static void blockConnection(Connection c) throws SQLException { + Class cls = c.getClass(); + // SQLServerConnection43 is returend for java >=9 otherwise SQLServerConnection + if (cls != SQLServerConnection.class) { + cls = cls.getSuperclass(); + } + + Field fields[] = cls.getDeclaredFields(); for (Field f : fields) { if (f.getName() == "activeConnectionProperties" && Properties.class == f.getType()) { f.setAccessible(true); @@ -281,7 +288,7 @@ public static void blockConnection(Connection c) throws SQLException { Assert.fail("Failed to block connection."); } - public static Map getUserOptions(Connection c) throws SQLException { + protected static Map getUserOptions(Connection c) throws SQLException { Map options = new HashMap<>(); try (Statement stmt = c.createStatement()) { try (ResultSet rs = stmt.executeQuery("DBCC USEROPTIONS")) { @@ -295,7 +302,7 @@ public static Map getUserOptions(Connection c) throws SQLExcepti return options; } - public static void toggleRandomProperties(Connection c) throws SQLException { + protected static void toggleRandomProperties(Connection c) throws SQLException { try (Statement stmt = c.createStatement()) { for (USER_OPTIONS uo : USER_OPTIONS.values()) { stmt.execute("SET " + uo.toString() + " " + uo.getValue()); @@ -303,11 +310,11 @@ public static void toggleRandomProperties(Connection c) throws SQLException { } } - public static int getRandomInt(int min, int max) { + protected static int getRandomInt(int min, int max) { return ThreadLocalRandom.current().nextInt(min, max); } - public static String getRandomString(String pool, int length) { + protected static String getRandomString(String pool, int length) { StringBuilder sb = new StringBuilder(); for (int i = 0; i < length; i++) { sb.append(String.valueOf(pool.charAt(getRandomInt(0, pool.length())))); @@ -315,7 +322,7 @@ public static String getRandomString(String pool, int length) { return sb.toString(); } - public static String setConnectionProps(String base, Map props) { + protected static String setConnectionProps(String base, Map props) { StringBuilder sb = new StringBuilder(); sb.append(base); props.forEach((k, v) -> sb.append(k).append("=").append(v).append(";"));