Skip to content

Commit

Permalink
Merge pull request apache#61 from abenaru/close-on-completion
Browse files Browse the repository at this point in the history
Implements close on completion
  • Loading branch information
JrJuscelino authored and vfraga committed Mar 29, 2022
1 parent 4f651c5 commit 880fbbe
Show file tree
Hide file tree
Showing 8 changed files with 566 additions and 232 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,19 @@

package org.apache.arrow.driver.jdbc;

import static java.util.Objects.isNull;
import static org.apache.arrow.driver.jdbc.utils.FlightStreamQueue.createNewQueue;
import static org.apache.arrow.util.Preconditions.checkNotNull;

import java.sql.ResultSet;
import java.sql.ResultSetMetaData;
import java.sql.SQLException;
import java.util.List;
import java.util.Optional;
import java.util.TimeZone;

import org.apache.arrow.driver.jdbc.utils.FlightStreamQueue;
import org.apache.arrow.flight.FlightStream;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.util.AutoCloseables;
import org.apache.calcite.avatica.AvaticaResultSet;
import org.apache.calcite.avatica.AvaticaStatement;
import org.apache.calcite.avatica.Meta;
Expand All @@ -49,27 +53,46 @@ public class ArrowFlightJdbcFlightStreamResultSet extends ArrowFlightJdbcVectorS
super(statement, state, signature, resultSetMetaData, timeZone, firstFrame);
}

@Override
protected AvaticaResultSet execute() throws SQLException {
protected FlightStreamQueue getFlightStreamQueue() {
return flightStreamQueue;
}

private void loadNewQueue() {
final Optional<FlightStreamQueue> oldQueue = Optional.ofNullable(getFlightStreamQueue());
try {
flightStreamQueue =
checkNotNull(createNewQueue(((ArrowFlightConnection) getStatement().getConnection()).getExecutorService()));
} catch (final SQLException e) {
throw new RuntimeException(e);
} finally {
oldQueue.ifPresent(AutoCloseables::closeNoChecked);
}
}

public FlightStream getCurrentFlightStream() {
return currentFlightStream;
}

private void loadNewFlightStream() {
final Optional<FlightStream> oldQueue = Optional.ofNullable(getCurrentFlightStream());
try {
final ArrowFlightConnection connection = (ArrowFlightConnection) statement.getConnection();
flightStreamQueue = new FlightStreamQueue(connection.getExecutorService());

final List<FlightStream> flightStreams = connection
.getClient()
.getFlightStreams(signature.sql);

flightStreams.forEach(flightStreamQueue::enqueue);

currentFlightStream = flightStreamQueue.next();
final VectorSchemaRoot root = currentFlightStream.getRoot();
execute(root);
} catch (SQLException e) {
throw e;
} catch (Exception e) {
throw new SQLException(e);
this.currentFlightStream = checkNotNull(getFlightStreamQueue().next());
} catch (final Exception e) {
throw new RuntimeException(e);
} finally {
oldQueue.ifPresent(AutoCloseables::closeNoChecked);
}
}

@Override
protected AvaticaResultSet execute() throws SQLException {
loadNewQueue();
getFlightStreamQueue().enqueue(
((ArrowFlightConnection) getStatement().getConnection())
.getClient().lazilyGetFlightStreams(signature.sql));
loadNewFlightStream();
// Ownership of the root will be passed onto the cursor.
execute(currentFlightStream.getRoot());
return this;
}

Expand All @@ -96,7 +119,11 @@ public boolean next() throws SQLException {
}

flightStreamQueue.enqueue(currentFlightStream);
currentFlightStream = flightStreamQueue.next();
try {
currentFlightStream = flightStreamQueue.next();
} catch (final Exception e) {
throw new SQLException(e);
}

if (currentFlightStream != null) {
execute(currentFlightStream.getRoot());
Expand All @@ -112,17 +139,13 @@ public boolean next() throws SQLException {
}

@Override
public void close() {
super.close();

public synchronized void close() {
try {
if (this.currentFlightStream != null) {
this.currentFlightStream.close();
}

flightStreamQueue.close();
} catch (Exception e) {
AutoCloseables.close(flightStreamQueue, isNull(currentFlightStream) ? null : currentFlightStream);
} catch (final Exception e) {
throw new RuntimeException(e);
} finally {
super.close();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,20 @@

package org.apache.arrow.driver.jdbc;

import static java.util.Objects.isNull;

import java.sql.ResultSet;
import java.sql.ResultSetMetaData;
import java.sql.SQLException;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.TimeZone;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import org.apache.arrow.driver.jdbc.utils.SqlTypes;
import org.apache.arrow.util.AutoCloseables;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.types.pojo.ArrowType;
import org.apache.arrow.vector.types.pojo.Field;
Expand Down Expand Up @@ -78,6 +83,25 @@ public static ArrowFlightJdbcVectorSchemaRootResultSet fromVectorSchemaRoot(Vect
return resultSet;
}

private static List<ColumnMetaData> convertArrowFieldsToColumnMetaDataList(List<Field> fields) {
return Stream.iterate(0, Math::incrementExact).limit(fields.size())
.map(index -> {
Field field = fields.get(index);
ArrowType.ArrowTypeID fieldTypeId = field.getType().getTypeID();

Common.ColumnMetaData.Builder builder = Common.ColumnMetaData.newBuilder();
builder.setOrdinal(index);
builder.setColumnName(field.getName());

builder.setType(Common.AvaticaType.newBuilder()
.setId(SqlTypes.getSqlTypeIdFromArrowType(field.getType()))
.setName(fieldTypeId.name())
.build());

return ColumnMetaData.fromProto(builder.build());
}).collect(Collectors.toList());
}

@Override
protected AvaticaResultSet execute() throws SQLException {
throw new RuntimeException();
Expand All @@ -95,35 +119,31 @@ void execute(VectorSchemaRoot vectorSchemaRoot) {

@Override
public void close() {
if (this.statement != null) {
// An ArrowFlightResultSet will have a null statement when it is created by
// ArrowFlightResultSet#fromVectorSchemaRoot. In this case it must skip calling AvaticaResultSet#close,
// as it expects that statement is not null
super.close();
final Set<Exception> exceptions = new HashSet<>();
try {
if (isClosed()) {
return;
}
} catch (final SQLException e) {
exceptions.add(e);
}

if (this.vectorSchemaRoot != null) {
this.vectorSchemaRoot.close();
if (!isNull(vectorSchemaRoot)) {
try {
AutoCloseables.close(vectorSchemaRoot);
} catch (final Exception e) {
exceptions.add(e);
}
}
}

private static List<ColumnMetaData> convertArrowFieldsToColumnMetaDataList(List<Field> fields) {
return Stream.iterate(0, Math::incrementExact).limit(fields.size())
.map(index -> {
Field field = fields.get(index);
ArrowType.ArrowTypeID fieldTypeId = field.getType().getTypeID();

Common.ColumnMetaData.Builder builder = Common.ColumnMetaData.newBuilder();
builder.setOrdinal(index);
builder.setColumnName(field.getName());

builder.setType(Common.AvaticaType.newBuilder()
.setId(SqlTypes.getSqlTypeIdFromArrowType(field.getType()))
.setName(fieldTypeId.name())
.build());

return ColumnMetaData.fromProto(builder.build());
}).collect(Collectors.toList());
if (!isNull(statement)) {
try {
super.close();
} catch (final Exception e) {
exceptions.add(e);
}
}
exceptions.parallelStream().forEach(e -> {
throw new RuntimeException(e);
});
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -17,28 +17,31 @@

package org.apache.arrow.driver.jdbc.client;

import static java.util.Collections.synchronizedSet;

import java.io.IOException;
import java.io.InputStream;
import java.nio.charset.StandardCharsets;
import java.security.GeneralSecurityException;
import java.util.ArrayDeque;
import java.util.Deque;
import java.util.HashSet;
import java.util.List;
import java.util.Optional;
import java.util.stream.Collectors;
import java.util.Set;
import java.util.stream.Stream;

import javax.annotation.Nullable;

import org.apache.arrow.driver.jdbc.client.utils.ClientAuthenticationUtils;
import org.apache.arrow.flight.FlightClient;
import org.apache.arrow.flight.FlightClient.Builder;
import org.apache.arrow.flight.FlightDescriptor;
import org.apache.arrow.flight.FlightEndpoint;
import org.apache.arrow.flight.FlightInfo;
import org.apache.arrow.flight.FlightStream;
import org.apache.arrow.flight.HeaderCallOption;
import org.apache.arrow.flight.Location;
import org.apache.arrow.flight.auth2.ClientBearerHeaderHandler;
import org.apache.arrow.flight.auth2.ClientIncomingAuthHeaderMiddleware;
import org.apache.arrow.flight.auth2.ClientIncomingAuthHeaderMiddleware.Factory;
import org.apache.arrow.flight.grpc.CredentialCallOption;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.util.AutoCloseables;
Expand All @@ -49,8 +52,7 @@
*/
public class ArrowFlightClientHandler implements FlightClientHandler {

private final Deque<AutoCloseable> resources =
new ArrayDeque<>();
private final Set<AutoCloseable> resources = synchronizedSet(new HashSet<>());
private final FlightClient client;

@Nullable
Expand Down Expand Up @@ -122,16 +124,9 @@ protected FlightInfo getInfo(final String query) {
}

@Override
public List<FlightStream> getFlightStreams(final String query) {
final FlightInfo flightInfo = getInfo(query);
final List<FlightEndpoint> endpoints = flightInfo.getEndpoints();

final List<FlightStream> streams =
endpoints.stream().map(flightEndpoint -> client.getStream(flightEndpoint.getTicket(), token))
.collect(Collectors.toList());
streams.forEach(resources::addFirst);

return streams;
public Stream<FlightStream> lazilyGetFlightStreams(final String query) {
final List<FlightEndpoint> endpoints = getInfo(query).getEndpoints();
return endpoints.stream().map(flightEndpoint -> client.getStream(flightEndpoint.getTicket(), token));
}

@Override
Expand Down Expand Up @@ -175,7 +170,7 @@ public static final ArrowFlightClientHandler getClient(
* Do NOT resort to creating labels and breaking from them! A better
* alternative would be splitting this method into smaller ones.
*/
final FlightClient.Builder builder = FlightClient.builder()
final Builder builder = FlightClient.builder()
.allocator(allocator);

ArrowFlightClientHandler handler;
Expand Down Expand Up @@ -205,7 +200,7 @@ public static final ArrowFlightClientHandler getClient(
// Build an unauthenticated client.
handler = new ArrowFlightClientHandler(client, properties);
} else {
final ClientIncomingAuthHeaderMiddleware.Factory factory = new ClientIncomingAuthHeaderMiddleware.Factory(
final Factory factory = new Factory(
new ClientBearerHeaderHandler());

builder.intercept(factory);
Expand All @@ -218,7 +213,7 @@ public static final ArrowFlightClientHandler getClient(
properties);
}

handler.resources.addLast(client);
handler.resources.add(client);
return handler;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,10 @@

package org.apache.arrow.driver.jdbc.client;

import java.util.List;
import static java.util.stream.Collectors.toList;

import java.util.Collection;
import java.util.stream.Stream;

import org.apache.arrow.flight.FlightClient;
import org.apache.arrow.flight.FlightInfo;
Expand All @@ -30,10 +33,21 @@ public interface FlightClientHandler extends AutoCloseable {

/**
* Makes an RPC "getStream" request based on the provided {@link FlightInfo}
* object. Retrieves result of the query previously prepared with "getInfo."
* object. Lazily retrieves result of the query previously prepared with "getInfo."
*
* @param query The query.
* @return a {@code FlightStream} of results.
*/
Stream<FlightStream> lazilyGetFlightStreams(String query);

/**
* Makes an RPC "getStream" request based on the provided {@link FlightInfo}
* object. Readily retrieves result of the query previously prepared with "getInfo."
*
* @param query The query.
* @return a {@code FlightStream} of results.
*/
List<FlightStream> getFlightStreams(String query);
default Collection<FlightStream> readilyGetFlightStreams(String query) {
return lazilyGetFlightStreams(query).collect(toList());
}
}
Loading

0 comments on commit 880fbbe

Please sign in to comment.