From 037a11d74fced7256c97a729e249f5649de9ae1e Mon Sep 17 00:00:00 2001 From: v-reye Date: Mon, 24 Sep 2018 17:07:26 -0700 Subject: [PATCH] Fixed a bug where calling length() after obtaining a stream would close the stream for Clobs/NClobs (#799) Addresses #788. Also addressed an issue where varchar(max)/Clob objects were always being decoded as UTF-16LE instead of using the collation specified in the Clob object. List of changes (the following applies to NClob as well, unless noted otherwise): 1. Calling Clob.length() no longer attempts to load the stream into a string. This fixes a null pointer issue as well as length() closing user streams. 2. Added streaming capabilities for Clob.getAsciiStream(). NClobs will always have a non-streaming implementation for getAsciiStream(). 3. Clobs no longer default to UTF-16LE. Clobs now respect the collation from SQL Server. NClobs remain unchanged and will always use UTF-16LE encoding. --- .../sqlserver/jdbc/PLPInputStream.java | 3 - .../sqlserver/jdbc/SQLServerClob.java | 50 +++- .../sqlserver/jdbc/SQLServerNClob.java | 9 +- .../sqlserver/jdbc/SimpleInputStream.java | 6 +- .../jdbc/unit/lobs/LobsStreamingTest.java | 276 ++++++++++++++++++ 5 files changed, 322 insertions(+), 22 deletions(-) create mode 100644 src/test/java/com/microsoft/sqlserver/jdbc/unit/lobs/LobsStreamingTest.java diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/PLPInputStream.java b/src/main/java/com/microsoft/sqlserver/jdbc/PLPInputStream.java index 0b4563e0b..8a502a370 100644 --- a/src/main/java/com/microsoft/sqlserver/jdbc/PLPInputStream.java +++ b/src/main/java/com/microsoft/sqlserver/jdbc/PLPInputStream.java @@ -25,9 +25,6 @@ class PLPInputStream extends BaseInputStream { static final int PLP_TERMINATOR = 0x00000000; private final static byte[] EMPTY_PLP_BYTES = new byte[0]; - // Stated length of the PLP stream payload; -1 if unknown length.
- int payloadLength; - private static final int PLP_EOS = -1; private int currentChunkRemain; diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerClob.java b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerClob.java index b59fd9062..79df3191c 100644 --- a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerClob.java +++ b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerClob.java @@ -5,11 +5,9 @@ package com.microsoft.sqlserver.jdbc; -import static java.nio.charset.StandardCharsets.US_ASCII; -import static java.nio.charset.StandardCharsets.UTF_16LE; - import java.io.BufferedInputStream; import java.io.BufferedReader; +import java.io.ByteArrayInputStream; import java.io.Closeable; import java.io.IOException; import java.io.InputStream; @@ -20,6 +18,7 @@ import java.io.StringReader; import java.io.UnsupportedEncodingException; import java.io.Writer; +import java.nio.charset.Charset; import java.sql.Clob; import java.sql.SQLException; import java.text.MessageFormat; @@ -152,7 +151,7 @@ abstract class SQLServerClobBase extends SQLServerLob implements Serializable { // The value of the CLOB that this Clob object represents. // This value is never null unless/until the free() method is called. - private String value; + protected String value; private final SQLCollation sqlCollation; @@ -181,8 +180,9 @@ final public String toString() { // Unique id generator for each instance (used for logging). static private final AtomicInteger baseID = new AtomicInteger(0); + private Charset defaultCharset = null; + // Returns unique id for each instance. 
- private static int nextInstanceID() { return baseID.incrementAndGet(); } @@ -281,9 +281,19 @@ public InputStream getAsciiStream() throws SQLException { if (null != sqlCollation && !sqlCollation.supportsAsciiConversion()) DataTypes.throwConversionError(getDisplayClassName(), "AsciiStream"); - getStringFromStream(); - InputStream getterStream = new BufferedInputStream( - new ReaderInputStream(new StringReader(value), US_ASCII, value.length())); + // Need to use a BufferedInputStream since the stream returned by this method is assumed to support mark/reset + InputStream getterStream = null; + if (null == value && !activeStreams.isEmpty()) { + InputStream inputStream = (InputStream) activeStreams.get(0); + try { + inputStream.reset(); + } catch (IOException e) { + SQLServerException.makeFromDriverError(con, null, e.getMessage(), null, false); + } + getterStream = new BufferedInputStream(inputStream); + } else { + getterStream = new ByteArrayInputStream(value.getBytes(java.nio.charset.StandardCharsets.US_ASCII)); + } activeStreams.add(getterStream); return getterStream; } @@ -301,11 +311,17 @@ public Reader getCharacterStream() throws SQLException { Reader getterStream = null; if (null == value && !activeStreams.isEmpty()) { InputStream inputStream = (InputStream) activeStreams.get(0); - getterStream = new BufferedReader(new InputStreamReader(inputStream, UTF_16LE)); + try { + inputStream.reset(); + } catch (IOException e) { + SQLServerException.makeFromDriverError(con, null, e.getMessage(), null, false); + } + Charset cs = (defaultCharset == null) ? 
typeInfo.getCharset() : defaultCharset; + getterStream = new BufferedReader(new InputStreamReader(inputStream, cs)); } else { getterStream = new StringReader(value); - activeStreams.add(getterStream); } + activeStreams.add(getterStream); return getterStream; } @@ -381,9 +397,8 @@ public String getSubString(long pos, int length) throws SQLException { */ public long length() throws SQLException { checkClosed(); - - if (value == null && activeStreams.get(0) instanceof PLPInputStream) { - return (long) ((PLPInputStream) activeStreams.get(0)).payloadLength / 2; + if (null == value && activeStreams.get(0) instanceof BaseInputStream) { + return (long) ((BaseInputStream) activeStreams.get(0)).payloadLength; } return value.length(); } @@ -410,9 +425,10 @@ private void getStringFromStream() throws SQLServerException { try { stream.reset(); } catch (IOException e) { - throw new SQLServerException(e.getMessage(), null, 0, e); + SQLServerException.makeFromDriverError(con, null, e.getMessage(), null, false); } - value = new String((stream).getBytes(), typeInfo.getCharset()); + Charset cs = (defaultCharset == null) ? 
typeInfo.getCharset() : defaultCharset; + value = new String(stream.getBytes(), cs); } } @@ -661,6 +677,10 @@ public int setString(long pos, String str, int offset, int len) throws SQLExcept return len; } + + protected void setDefaultCharset(Charset c) { + this.defaultCharset = c; + } } diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerNClob.java b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerNClob.java index dfa9fb12d..c3338278b 100644 --- a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerNClob.java +++ b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerNClob.java @@ -32,10 +32,12 @@ public final class SQLServerNClob extends SQLServerClobBase implements NClob { SQLServerNClob(SQLServerConnection connection) { super(connection, "", connection.getDatabaseCollation(), logger, null); + this.setDefaultCharset(java.nio.charset.StandardCharsets.UTF_16LE); } SQLServerNClob(BaseInputStream stream, TypeInfo typeInfo) throws SQLServerException, UnsupportedEncodingException { super(null, stream, typeInfo.getSQLCollation(), logger, typeInfo); + this.setDefaultCharset(java.nio.charset.StandardCharsets.UTF_16LE); } @Override @@ -45,6 +47,9 @@ public void free() throws SQLException { @Override public InputStream getAsciiStream() throws SQLException { + // NClobs are mapped to Nvarchar(max), and are always UTF-16 encoded. This API expects a US_ASCII stream. + // It's not possible to modify the stream without loading it into memory. Users should use getCharacterStream. + this.fillFromStream(); return super.getAsciiStream(); } @@ -65,7 +70,9 @@ public String getSubString(long pos, int length) throws SQLException { @Override public long length() throws SQLException { - return super.length(); + // If streaming, every 2 bytes represents 1 character. If not, length() just returns string length + long length = super.length(); + return (null == value) ? 
length / 2 : length; } @Override diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/SimpleInputStream.java b/src/main/java/com/microsoft/sqlserver/jdbc/SimpleInputStream.java index 2135858c1..fa22c96ca 100644 --- a/src/main/java/com/microsoft/sqlserver/jdbc/SimpleInputStream.java +++ b/src/main/java/com/microsoft/sqlserver/jdbc/SimpleInputStream.java @@ -24,6 +24,9 @@ abstract class BaseInputStream extends InputStream { // Flag indicating whether the stream consumes and discards data as it reads it final boolean isStreaming; + + // Stated length of the payload + int payloadLength; /** Generate the logging ID */ private String parentLoggingInfo = ""; @@ -131,9 +134,6 @@ void resetHelper() throws IOException { final class SimpleInputStream extends BaseInputStream { - // Stated length of the payload - private final int payloadLength; - /** * Initializes the input stream. */ diff --git a/src/test/java/com/microsoft/sqlserver/jdbc/unit/lobs/LobsStreamingTest.java b/src/test/java/com/microsoft/sqlserver/jdbc/unit/lobs/LobsStreamingTest.java new file mode 100644 index 000000000..55e8e1576 --- /dev/null +++ b/src/test/java/com/microsoft/sqlserver/jdbc/unit/lobs/LobsStreamingTest.java @@ -0,0 +1,276 @@ +package com.microsoft.sqlserver.jdbc.unit.lobs; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import java.io.IOException; +import java.io.InputStream; +import java.io.Reader; +import java.sql.Clob; +import java.sql.Connection; +import java.sql.DriverManager; +import java.sql.NClob; +import java.sql.PreparedStatement; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.sql.Statement; +import java.util.ArrayList; +import java.util.Random; +import java.util.concurrent.ThreadLocalRandom; +import java.util.stream.IntStream; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; +import org.junit.platform.runner.JUnitPlatform; +import org.junit.runner.RunWith; + 
+import com.microsoft.sqlserver.jdbc.RandomUtil; +import com.microsoft.sqlserver.jdbc.TestUtils; +import com.microsoft.sqlserver.testframework.AbstractTest; + + +@RunWith(JUnitPlatform.class) +public class LobsStreamingTest extends AbstractTest { + + private static final int LOB_ARRAY_SIZE = 500; // number of rows to insert into the table and compare + private static final int LOB_LENGTH_MIN = 8000; + private static final int LOB_LENGTH_MAX = 32000; + private static final String ASCII_CHARACTERS = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890!@#$%^&*()-=_+,./;'[]<>?:{}|`~\"\\"; + private static final String UNICODE_CHARACTERS = ASCII_CHARACTERS + + "Ǥ⚌c♮ƺåYèĢù⚏Ȓ★njäõpƸŃōoƝĤßuÙőƆE♹gLJÜŬȺDZ!Û☵ŦãǁĸNQŰǚǻTÖC]ǶýåÉbɉ☩=\\ȍáźŗǃĻýű☓☄¸T☑ö^k☏I:x☑⚀läiȉ☱☚⚅ǸǎãÂ"; + private static String tableName = null; + + private static enum Lob { + CLOB, + NCLOB + }; + + @BeforeEach + public void init() throws SQLException { + tableName = RandomUtil.getIdentifier("streamingTest"); + } + + private String getRandomString(int length, String validCharacters) { + StringBuilder salt = new StringBuilder(); + Random rnd = new Random(); + while (salt.length() < length) { + int index = (int) (rnd.nextFloat() * validCharacters.length()); + salt.append(validCharacters.charAt(index)); + } + String saltStr = salt.toString(); + return saltStr; + } + + // closing the scanner closes the Inputstream, and the driver needs the stream to fill LoBs + @SuppressWarnings("resource") + private String getStringFromInputStream(InputStream is) { + java.util.Scanner s = new java.util.Scanner(is, java.nio.charset.StandardCharsets.US_ASCII).useDelimiter("\\A"); + return s.hasNext() ? 
s.next() : ""; + } + + private String getStringFromReader(Reader r, long l) throws IOException { + // read the Reader contents into a buffer and return the complete string + final StringBuilder stringBuilder = new StringBuilder((int) l); + char[] buffer = new char[(int) l]; + int amountRead = -1; + while ((amountRead = r.read(buffer, 0, (int) l)) != -1) { + stringBuilder.append(buffer, 0, amountRead); + } + return stringBuilder.toString(); + } + + private void createLobTable(Statement stmt, String table, Lob l) throws SQLException { + String columnType = (l == Lob.CLOB) ? "varchar(max)" : "nvarchar(max)"; + stmt.execute("CREATE TABLE [" + table + "] (id int, lobValue " + columnType + ")"); + } + + private ArrayList<String> createRandomStringArray(Lob l) { + String characterPool = (l == Lob.CLOB) ? ASCII_CHARACTERS : UNICODE_CHARACTERS; + ArrayList<String> string_array = new ArrayList<>(); + IntStream.range(0, LOB_ARRAY_SIZE).forEach(i -> string_array.add( + getRandomString(ThreadLocalRandom.current().nextInt(LOB_LENGTH_MIN, LOB_LENGTH_MAX), characterPool))); + return string_array; + } + + private void insertData(Connection conn, String table, ArrayList<String> lobs) throws SQLException { + try (PreparedStatement pstmt = conn.prepareStatement("INSERT INTO [" + table + "] VALUES(?,?)")) { + for (int i = 0; i < lobs.size(); i++) { + Clob c = conn.createClob(); + c.setString(1, lobs.get(i)); + pstmt.setInt(1, i); + pstmt.setClob(2, c); + pstmt.addBatch(); + } + pstmt.executeBatch(); + } + } + + @Test + @DisplayName("testLengthAfterStream") + public void testLengthAfterStream() throws SQLException, IOException { + try (Connection conn = DriverManager.getConnection(connectionString);) { + try (Statement stmt = conn.createStatement()) { + TestUtils.dropTableIfExists(tableName, stmt); + createLobTable(stmt, tableName, Lob.CLOB); + ArrayList<String> lob_data = createRandomStringArray(Lob.CLOB); + insertData(conn, tableName, lob_data); + + try (ResultSet rs = stmt.executeQuery("SELECT * FROM [" + tableName
+ "] ORDER BY id ASC")) { + while (rs.next()) { + Clob c = rs.getClob(2); + Reader r = c.getCharacterStream(); + long clobLength = c.length(); + String received = getStringFromReader(r, clobLength);// streaming string + c.free(); + assertEquals(lob_data.get(rs.getInt(1)), received);// compare streamed string to initial string + } + } + } finally { + try (Statement stmt = conn.createStatement()) { + TestUtils.dropTableIfExists(tableName, stmt); + } + } + } + } + + @Test + @DisplayName("testClobsVarcharASCII") + public void testClobsVarcharASCII() throws SQLException { + try (Connection conn = DriverManager.getConnection(connectionString)) { + try (Statement stmt = conn.createStatement()) { + TestUtils.dropTableIfExists(tableName, stmt); + createLobTable(stmt, tableName, Lob.CLOB); + ArrayList<String> lob_data = createRandomStringArray(Lob.CLOB); + insertData(conn, tableName, lob_data); + + ArrayList<Clob> lobsFromServer = new ArrayList<>(); + try (ResultSet rs = stmt.executeQuery("SELECT * FROM [" + tableName + "] ORDER BY id ASC")) { + while (rs.next()) { + int index = rs.getInt(1); + Clob c = rs.getClob(2); + assertEquals(c.length(), lob_data.get(index).length()); + lobsFromServer.add(c); + String received = getStringFromInputStream(c.getAsciiStream());// streaming string + assertEquals(lob_data.get(index), received);// compare streamed string to initial string + } + } + for (int i = 0; i < lob_data.size(); i++) { + String received = getStringFromInputStream(lobsFromServer.get(i).getAsciiStream());// non-streaming + // string + assertEquals(received, lob_data.get(i));// compare static string to streamed string + } + for (Clob c : lobsFromServer) { + c.free(); + } + } finally { + try (Statement stmt = conn.createStatement()) { + TestUtils.dropTableIfExists(tableName, stmt); + } + } + } + } + + @Test + @DisplayName("testNClobsNVarcharASCII") + public void testNClobsVarcharASCII() throws SQLException, IOException { + try (Connection conn =
DriverManager.getConnection(connectionString)) { + try (Statement stmt = conn.createStatement()) { + TestUtils.dropTableIfExists(tableName, stmt); + createLobTable(stmt, tableName, Lob.NCLOB); + // Testing AsciiStream, use Clob string set or characters will be converted to '?' + ArrayList<String> lob_data = createRandomStringArray(Lob.CLOB); + insertData(conn, tableName, lob_data); + + try (ResultSet rs = stmt.executeQuery("SELECT * FROM [" + tableName + "] ORDER BY id ASC")) { + while (rs.next()) { + int index = rs.getInt(1); + NClob c = rs.getNClob(2); + assertEquals(c.length(), lob_data.get(index).length()); + String received = getStringFromInputStream(c.getAsciiStream());// NClob AsciiStream is never + // streamed + c.free(); + assertEquals(lob_data.get(index), received);// compare string to initial string + } + } + } finally { + try (Statement stmt = conn.createStatement()) { + TestUtils.dropTableIfExists(tableName, stmt); + } + } + } + } + + @Test + @DisplayName("testClobsVarcharCHARA") + public void testClobsVarcharCHARA() throws SQLException, IOException { + try (Connection conn = DriverManager.getConnection(connectionString)) { + try (Statement stmt = conn.createStatement()) { + TestUtils.dropTableIfExists(tableName, stmt); + createLobTable(stmt, tableName, Lob.CLOB); + ArrayList<String> lob_data = createRandomStringArray(Lob.CLOB); + insertData(conn, tableName, lob_data); + + ArrayList<Clob> lobsFromServer = new ArrayList<>(); + try (ResultSet rs = stmt.executeQuery("SELECT * FROM [" + tableName + "] ORDER BY id ASC")) { + while (rs.next()) { + int index = rs.getInt(1); + Clob c = rs.getClob(2); + assertEquals(c.length(), lob_data.get(index).length()); + lobsFromServer.add(c); + String received = getStringFromReader(c.getCharacterStream(), c.length());// streaming string + assertEquals(lob_data.get(index), received);// compare streamed string to initial string + } + } + for (int i = 0; i < lob_data.size(); i++) { + String received =
getStringFromReader(lobsFromServer.get(i).getCharacterStream(), + lobsFromServer.get(i).length());// non-streaming string + assertEquals(received, lob_data.get(i));// compare static string to streamed string + } + for (Clob c : lobsFromServer) { + c.free(); + } + } finally { + try (Statement stmt = conn.createStatement()) { + TestUtils.dropTableIfExists(tableName, stmt); + } + } + } + } + + @Test + @DisplayName("testNClobsVarcharCHARA") + public void testNClobsVarcharCHARA() throws SQLException, IOException { + try (Connection conn = DriverManager.getConnection(connectionString)) { + try (Statement stmt = conn.createStatement()) { + TestUtils.dropTableIfExists(tableName, stmt); + createLobTable(stmt, tableName, Lob.NCLOB); + ArrayList<String> lob_data = createRandomStringArray(Lob.NCLOB); + insertData(conn, tableName, lob_data); + + ArrayList<NClob> lobsFromServer = new ArrayList<>(); + try (ResultSet rs = stmt.executeQuery("SELECT * FROM [" + tableName + "] ORDER BY id ASC")) { + while (rs.next()) { + int index = rs.getInt(1); + NClob c = rs.getNClob(2); + assertEquals(c.length(), lob_data.get(index).length()); + lobsFromServer.add(c); + String received = getStringFromReader(c.getCharacterStream(), c.length());// streaming string + assertEquals(lob_data.get(index), received);// compare streamed string to initial string + } + } + for (int i = 0; i < lob_data.size(); i++) { + String received = getStringFromReader(lobsFromServer.get(i).getCharacterStream(), + lobsFromServer.get(i).length());// non-streaming string + assertEquals(received, lob_data.get(i));// compare static string to streamed string + } + for (Clob c : lobsFromServer) { + c.free(); + } + } finally { + try (Statement stmt = conn.createStatement()) { + TestUtils.dropTableIfExists(tableName, stmt); + } + } + } + } +}