From 091b28aefa22faf63f47aa99199dcaf5b346984f Mon Sep 17 00:00:00 2001 From: Rafael Telles Date: Tue, 5 Oct 2021 15:45:54 -0300 Subject: [PATCH] Fix Schema serialization and deserialization on Flight SQL methods --- .../arrow/flight/sql/FlightSqlClient.java | 25 +++++++++++++------ .../apache/arrow/flight/TestFlightSql.java | 19 +++++++++----- .../flight/sql/example/FlightSqlExample.java | 23 +++++++++++++---- 3 files changed, 49 insertions(+), 18 deletions(-) diff --git a/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/FlightSqlClient.java b/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/FlightSqlClient.java index ebe635e61830c..c1b820b85c704 100644 --- a/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/FlightSqlClient.java +++ b/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/FlightSqlClient.java @@ -31,6 +31,9 @@ import static org.apache.arrow.flight.sql.impl.FlightSql.CommandStatementUpdate; import static org.apache.arrow.flight.sql.impl.FlightSql.DoPutUpdateResult; +import com.fasterxml.jackson.databind.util.ByteBufferBackedInputStream; +import java.io.IOException; +import java.nio.channels.Channels; import java.sql.SQLException; import java.util.Arrays; import java.util.Collections; @@ -40,7 +43,6 @@ import java.util.concurrent.ExecutionException; import java.util.stream.Collectors; -import org.apache.arrow.flatbuf.Message; import org.apache.arrow.flight.Action; import org.apache.arrow.flight.CallOption; import org.apache.arrow.flight.CallStatus; @@ -60,6 +62,7 @@ import org.apache.arrow.util.AutoCloseables; import org.apache.arrow.util.Preconditions; import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.ipc.ReadChannel; import org.apache.arrow.vector.ipc.message.MessageSerializer; import org.apache.arrow.vector.types.pojo.Schema; @@ -452,9 +455,7 @@ public void clearParameters() { public Schema getResultSetSchema() { if (resultSetSchema == null) { final ByteString 
bytes = preparedStatementResult.getDatasetSchema();
-        resultSetSchema = bytes.isEmpty() ?
-            new Schema(Collections.emptyList()) :
-            MessageSerializer.deserializeSchema(Message.getRootAsMessage(bytes.asReadOnlyByteBuffer()));
+        resultSetSchema = deserializeSchema(bytes);
       }
       return resultSetSchema;
     }
@@ -467,13 +468,30 @@ public Schema getResultSetSchema() {
     public Schema getParameterSchema() {
       if (parameterSchema == null) {
         final ByteString bytes = preparedStatementResult.getParameterSchema();
-        parameterSchema = bytes.isEmpty() ?
-            new Schema(Collections.emptyList()) :
-            MessageSerializer.deserializeSchema(Message.getRootAsMessage(bytes.asReadOnlyByteBuffer()));
+        parameterSchema = deserializeSchema(bytes);
       }
       return parameterSchema;
     }
 
+    /**
+     * Deserializes a {@link Schema} from the serialized bytes held by the prepared statement result.
+     *
+     * @param bytes the serialized schema message; may be empty.
+     * @return the decoded schema, or a schema with no fields when {@code bytes} is empty.
+     * @throws RuntimeException if reading the schema from {@code bytes} fails with an {@link IOException}.
+     */
+    private Schema deserializeSchema(final ByteString bytes) {
+      if (bytes.isEmpty()) {
+        return new Schema(Collections.emptyList());
+      }
+      // ByteString exposes its own InputStream, so no third-party stream adapter is required.
+      try (final ReadChannel channel = new ReadChannel(Channels.newChannel(bytes.newInput()))) {
+        return MessageSerializer.deserializeSchema(channel);
+      } catch (final IOException e) {
+        throw new RuntimeException("Failed to deserialize schema", e);
+      }
+    }
+
     /**
      * Executes the prepared statement query on the server.
* diff --git a/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/TestFlightSql.java b/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/TestFlightSql.java index 5ecdf41789e60..e99adf8d9a81c 100644 --- a/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/TestFlightSql.java +++ b/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/TestFlightSql.java @@ -27,7 +27,9 @@ import static org.hamcrest.CoreMatchers.notNullValue; import static org.hamcrest.CoreMatchers.nullValue; -import java.nio.ByteBuffer; +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.nio.channels.Channels; import java.sql.SQLException; import java.util.ArrayList; import java.util.LinkedHashMap; @@ -36,7 +38,6 @@ import java.util.Objects; import java.util.stream.IntStream; -import org.apache.arrow.flatbuf.Message; import org.apache.arrow.flight.sql.FlightSqlClient; import org.apache.arrow.flight.sql.FlightSqlClient.PreparedStatement; import org.apache.arrow.flight.sql.FlightSqlProducer; @@ -53,6 +54,7 @@ import org.apache.arrow.vector.VarCharVector; import org.apache.arrow.vector.VectorSchemaRoot; import org.apache.arrow.vector.complex.DenseUnionVector; +import org.apache.arrow.vector.ipc.ReadChannel; import org.apache.arrow.vector.ipc.message.MessageSerializer; import org.apache.arrow.vector.types.Types.MinorType; import org.apache.arrow.vector.types.pojo.Field; @@ -606,10 +608,15 @@ List> getResults(FlightStream stream) { final VarBinaryVector varbinaryVector = (VarBinaryVector) fieldVector; for (int rowIndex = 0; rowIndex < rowCount; rowIndex++) { final byte[] data = varbinaryVector.getObject(rowIndex); - final String output = - isNull(data) ? - null : - MessageSerializer.deserializeSchema(Message.getRootAsMessage(ByteBuffer.wrap(data))).toJson(); + final String output; + try { + output = isNull(data) ? 
+ null : + MessageSerializer.deserializeSchema( + new ReadChannel(Channels.newChannel(new ByteArrayInputStream(data)))).toJson(); + } catch (final IOException e) { + throw new RuntimeException("Failed to deserialize schema", e); + } results.get(rowIndex).add(output); } } else if (fieldVector instanceof DenseUnionVector) { diff --git a/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/sql/example/FlightSqlExample.java b/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/sql/example/FlightSqlExample.java index 3c6b9ad2af683..ecbf855810826 100644 --- a/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/sql/example/FlightSqlExample.java +++ b/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/sql/example/FlightSqlExample.java @@ -30,9 +30,12 @@ import static org.apache.arrow.util.Preconditions.checkState; import static org.slf4j.LoggerFactory.getLogger; +import java.io.ByteArrayOutputStream; import java.io.File; import java.io.IOException; import java.math.BigDecimal; +import java.nio.ByteBuffer; +import java.nio.channels.Channels; import java.nio.charset.StandardCharsets; import java.nio.file.Files; import java.nio.file.NoSuchFileException; @@ -137,6 +140,7 @@ import org.apache.arrow.vector.holders.NullableBitHolder; import org.apache.arrow.vector.holders.NullableUInt4Holder; import org.apache.arrow.vector.holders.NullableVarCharHolder; +import org.apache.arrow.vector.ipc.WriteChannel; import org.apache.arrow.vector.ipc.message.IpcOption; import org.apache.arrow.vector.ipc.message.MessageSerializer; import org.apache.arrow.vector.types.Types.MinorType; @@ -541,7 +545,7 @@ private static VectorSchemaRoot getTablesRoot(final DatabaseMetaData databaseMet final String tableName = tableNameVector.getObject(index).toString(); final Schema schema = new Schema(tableToFields.get(tableName)); saveToVector( - copyFrom(MessageSerializer.serializeMetadata(schema, DEFAULT_OPTION)).toByteArray(), + 
copyFrom(serializeMetadata(schema)).toByteArray(),
             tableSchemaVector, index);
       }
     }
@@ -554,6 +558,24 @@ private static VectorSchemaRoot getTablesRoot(final DatabaseMetaData databaseMet
     return new VectorSchemaRoot(vectors);
   }
 
+  /**
+   * Serializes a {@link Schema} into an in-memory buffer by writing it through a {@link WriteChannel}.
+   *
+   * @param schema the schema to serialize.
+   * @return a {@link ByteBuffer} wrapping the serialized bytes.
+   * @throws RuntimeException if writing the schema fails with an {@link IOException}.
+   */
+  private static ByteBuffer serializeMetadata(final Schema schema) {
+    final ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
+    try (final WriteChannel channel = new WriteChannel(Channels.newChannel(outputStream))) {
+      MessageSerializer.serialize(channel, schema);
+    } catch (final IOException e) {
+      throw new RuntimeException("Failed to serialize schema", e);
+    }
+    // The in-memory stream keeps its contents after the channel wrapping it has been closed.
+    return ByteBuffer.wrap(outputStream.toByteArray());
+  }
+
   private static VectorSchemaRoot getSqlInfoRoot(final DatabaseMetaData metaData, final BufferAllocator allocator,
                                                  final Iterable<Integer> requestedInfo) throws SQLException {
     return getSqlInfoRoot(metaData, allocator, stream(requestedInfo.spliterator(), false).toArray(Integer[]::new));
@@ -777,12 +799,10 @@ public void createPreparedStatement(final ActionCreatePreparedStatementRequest r
       final ByteString bytes = isNull(metaData) ?
           ByteString.EMPTY :
           ByteString.copyFrom(
-              MessageSerializer.serializeMetadata(
-                  jdbcToArrowSchema(metaData, DEFAULT_CALENDAR),
-                  DEFAULT_OPTION));
+              serializeMetadata(jdbcToArrowSchema(metaData, DEFAULT_CALENDAR)));
       final ActionCreatePreparedStatementResult result = ActionCreatePreparedStatementResult.newBuilder()
           .setDatasetSchema(bytes)
-          .setParameterSchema(copyFrom(MessageSerializer.serializeMetadata(parameterSchema, DEFAULT_OPTION)))
+          .setParameterSchema(copyFrom(serializeMetadata(parameterSchema)))
           .setPreparedStatementHandle(preparedStatementHandle)
           .build();
       listener.onNext(new Result(pack(result).toByteArray()));