Skip to content

Commit

Permalink
Merge pull request #1269 from tristaZero/dev
Browse files Browse the repository at this point in the history
Add AbstractStatementExecutor
  • Loading branch information
terrymanu authored Sep 17, 2018
2 parents 9363b4b + b61c580 commit d2d2598
Show file tree
Hide file tree
Showing 8 changed files with 187 additions and 271 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
/*
* Copyright 2016-2018 shardingsphere.io.
* <p>
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* </p>
*/

package io.shardingsphere.core.executor;

import io.shardingsphere.core.constant.DatabaseType;
import io.shardingsphere.core.constant.SQLType;
import io.shardingsphere.core.executor.sql.execute.SQLExecuteCallback;
import io.shardingsphere.core.executor.sql.execute.SQLExecuteTemplate;
import io.shardingsphere.core.executor.sql.prepare.SQLExecutePrepareTemplate;
import io.shardingsphere.core.jdbc.core.connection.ShardingConnection;
import lombok.AccessLevel;
import lombok.Getter;

import java.sql.Connection;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.Collection;
import java.util.LinkedList;
import java.util.List;
import java.util.concurrent.CopyOnWriteArrayList;

/**
* Abstract statement executor.
*
* @author panjuan
*/
@Getter(AccessLevel.PROTECTED)
public class AbstractStatementExecutor {

private final DatabaseType databaseType;

protected SQLType sqlType;

@Getter
private final int resultSetType;

@Getter
private final int resultSetConcurrency;

@Getter
private final int resultSetHoldability;

private final ShardingConnection connection;

private final SQLExecutePrepareTemplate sqlExecutePrepareTemplate;

private final SQLExecuteTemplate sqlExecuteTemplate;

private final Collection<Connection> connections = new LinkedList<>();

@Getter
private final List<List<Object>> parameterSets = new LinkedList<>();

@Getter
private final List<Statement> statements = new LinkedList<>();

@Getter
private final List<ResultSet> resultSets = new CopyOnWriteArrayList<>();

private final Collection<ShardingExecuteGroup<StatementExecuteUnit>> executeGroups = new LinkedList<>();

public AbstractStatementExecutor(final int resultSetType, final int resultSetConcurrency, final int resultSetHoldability, final ShardingConnection shardingConnection) {
this.databaseType = shardingConnection.getShardingDataSource().getShardingContext().getDatabaseType();
this.resultSetType = resultSetType;
this.resultSetConcurrency = resultSetConcurrency;
this.resultSetHoldability = resultSetHoldability;
this.connection = shardingConnection;
sqlExecuteTemplate = new SQLExecuteTemplate(connection.getShardingDataSource().getShardingContext().getExecuteEngine());
sqlExecutePrepareTemplate = new SQLExecutePrepareTemplate(connection.getShardingDataSource().getShardingContext().getMaxConnectionsSizePerQuery());
}

@SuppressWarnings("unchecked")
protected <T> List<T> executeCallback(final SQLExecuteCallback<T> executeCallback) throws SQLException {
return sqlExecuteTemplate.executeGroup((Collection) executeGroups, executeCallback);
}

/**
* Clear data.
*
* @throws SQLException sql exception
*/
public void clear() throws SQLException {
clearStatements();
clearConnections();
statements.clear();
parameterSets.clear();
connections.clear();
resultSets.clear();
executeGroups.clear();
}

private void clearStatements() throws SQLException {
for (Statement each : getStatements()) {
each.close();
}
}

private void clearConnections() {
for (Connection each : connections) {
connection.release(each);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,10 @@
import com.google.common.collect.Lists;
import io.shardingsphere.core.constant.ConnectionMode;
import io.shardingsphere.core.constant.DatabaseType;
import io.shardingsphere.core.constant.SQLType;
import io.shardingsphere.core.executor.sql.execute.SQLExecuteCallback;
import io.shardingsphere.core.executor.sql.execute.SQLExecuteTemplate;
import io.shardingsphere.core.executor.sql.execute.threadlocal.ExecutorDataMap;
import io.shardingsphere.core.executor.sql.execute.threadlocal.ExecutorExceptionHandler;
import io.shardingsphere.core.executor.sql.prepare.SQLExecutePrepareCallback;
import io.shardingsphere.core.executor.sql.prepare.SQLExecutePrepareTemplate;
import io.shardingsphere.core.jdbc.core.connection.ShardingConnection;
import io.shardingsphere.core.routing.BatchRouteUnit;
import io.shardingsphere.core.routing.RouteUnit;
Expand All @@ -39,7 +36,6 @@

import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.ArrayList;
Expand All @@ -48,7 +44,6 @@
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.concurrent.CopyOnWriteArrayList;

/**
* Prepared statement executor to process add batch.
Expand All @@ -57,60 +52,32 @@
* @author maxiaoguang
* @author panjuan
*/
public final class BatchPreparedStatementExecutor {

private final DatabaseType databaseType;

private SQLType sqlType;

private int batchCount;

private final int resultSetType;

private final int resultSetConcurrency;

private final int resultSetHoldability;

private final boolean returnGeneratedKeys;

private final ShardingConnection connection;
public final class BatchPreparedStatementExecutor extends AbstractStatementExecutor {

private final Collection<BatchRouteUnit> routeUnits = new LinkedList<>();

private final SQLExecuteTemplate sqlExecuteTemplate;

private final SQLExecutePrepareTemplate sqlExecutePrepareTemplate;

@Getter
private final List<ResultSet> resultSets = new CopyOnWriteArrayList<>();

private final Collection<Connection> connections = new LinkedList<>();
private final boolean returnGeneratedKeys;

private final Collection<ShardingExecuteGroup<StatementExecuteUnit>> executeGroups = new LinkedList<>();
private int batchCount;

public BatchPreparedStatementExecutor(final int resultSetType, final int resultSetConcurrency, final int resultSetHoldability, final boolean returnGeneratedKeys,
final ShardingConnection shardingConnection) {
this.databaseType = shardingConnection.getShardingDataSource().getShardingContext().getDatabaseType();
this.resultSetType = resultSetType;
this.resultSetConcurrency = resultSetConcurrency;
this.resultSetHoldability = resultSetHoldability;
super(resultSetType, resultSetConcurrency, resultSetHoldability, shardingConnection);
this.returnGeneratedKeys = returnGeneratedKeys;
this.connection = shardingConnection;
sqlExecuteTemplate = new SQLExecuteTemplate(connection.getShardingDataSource().getShardingContext().getExecuteEngine());
sqlExecutePrepareTemplate = new SQLExecutePrepareTemplate(connection.getShardingDataSource().getShardingContext().getMaxConnectionsSizePerQuery());
}

/**
* Init executor.
* Initialize executor.
*
* @exception SQLException sql exception
* @throws SQLException SQL exception
*/
public void init() throws SQLException {
executeGroups.addAll(obtainExecuteGroups(routeUnits));
getExecuteGroups().addAll(obtainExecuteGroups(routeUnits));
}

private Collection<ShardingExecuteGroup<StatementExecuteUnit>> obtainExecuteGroups(final Collection<BatchRouteUnit> routeUnits) throws SQLException {
return sqlExecutePrepareTemplate.getExecuteUnitGroups(Lists.transform(new ArrayList<>(routeUnits), new Function<BatchRouteUnit, RouteUnit>() {
return getSqlExecutePrepareTemplate().getExecuteUnitGroups(Lists.transform(new ArrayList<>(routeUnits), new Function<BatchRouteUnit, RouteUnit>() {

@Override
public RouteUnit apply(final BatchRouteUnit input) {
Expand All @@ -120,8 +87,8 @@ public RouteUnit apply(final BatchRouteUnit input) {

@Override
public Connection getConnection(final String dataSourceName) throws SQLException {
Connection conn = connection.getNewConnection(dataSourceName);
connections.add(conn);
Connection conn = BatchPreparedStatementExecutor.super.getConnection().getNewConnection(dataSourceName);
getConnections().add(conn);
return conn;
}

Expand All @@ -134,7 +101,7 @@ public StatementExecuteUnit createStatementExecuteUnit(final Connection connecti
}

private PreparedStatement createPreparedStatement(final Connection connection, final String sql) throws SQLException {
return returnGeneratedKeys ? connection.prepareStatement(sql, Statement.RETURN_GENERATED_KEYS) : connection.prepareStatement(sql, resultSetType, resultSetConcurrency, resultSetHoldability);
return returnGeneratedKeys ? connection.prepareStatement(sql, Statement.RETURN_GENERATED_KEYS) : connection.prepareStatement(sql, getResultSetType(), getResultSetConcurrency(), getResultSetHoldability());
}

/**
Expand Down Expand Up @@ -193,7 +160,7 @@ private void handleNewRouteUnits(final Collection<BatchRouteUnit> newRouteUnits)
public int[] executeBatch() throws SQLException {
final boolean isExceptionThrown = ExecutorExceptionHandler.isExceptionThrown();
final Map<String, Object> dataMap = ExecutorDataMap.getDataMap();
SQLExecuteCallback<int[]> callback = new SQLExecuteCallback<int[]>(databaseType, sqlType, isExceptionThrown, dataMap) {
SQLExecuteCallback<int[]> callback = new SQLExecuteCallback<int[]>(getDatabaseType(), sqlType, isExceptionThrown, dataMap) {

@Override
protected int[] executeSQL(final StatementExecuteUnit statementExecuteUnit) throws SQLException {
Expand All @@ -209,7 +176,7 @@ private int[] accumulate(final List<int[]> results) {
for (BatchRouteUnit each : routeUnits) {
for (Entry<Integer, Integer> entry : each.getJdbcAndActualAddBatchCallTimesMap().entrySet()) {
int value = null == results.get(count) ? 0 : results.get(count)[entry.getValue()];
if (DatabaseType.Oracle == databaseType) {
if (DatabaseType.Oracle == getDatabaseType()) {
result[entry.getKey()] = value;
} else {
result[entry.getKey()] += value;
Expand All @@ -220,19 +187,15 @@ private int[] accumulate(final List<int[]> results) {
return result;
}

@SuppressWarnings("unchecked")
private <T> List<T> executeCallback(final SQLExecuteCallback<T> executeCallback) throws SQLException {
return sqlExecuteTemplate.executeGroup((Collection) executeGroups, executeCallback);
}

/**
* Get statements.
*
* @return statements
*/
@Override
public List<Statement> getStatements() {
List<Statement> result = new LinkedList<>();
for (ShardingExecuteGroup<StatementExecuteUnit> each : executeGroups) {
for (ShardingExecuteGroup<StatementExecuteUnit> each : getExecuteGroups()) {
result.addAll(Lists.transform(each.getInputs(), new Function<StatementExecuteUnit, Statement>() {

@Override
Expand All @@ -253,7 +216,7 @@ public Statement apply(final StatementExecuteUnit input) {
public List<List<Object>> getParameterSet(final Statement statement) {
Optional<StatementExecuteUnit> target;
List<List<Object>> result = new LinkedList<>();
for (ShardingExecuteGroup<StatementExecuteUnit> each : executeGroups) {
for (ShardingExecuteGroup<StatementExecuteUnit> each : getExecuteGroups()) {
target = Iterators.tryFind(each.getInputs().iterator(), new Predicate<StatementExecuteUnit>() {
@Override
public boolean apply(final StatementExecuteUnit input) {
Expand All @@ -268,31 +231,11 @@ public boolean apply(final StatementExecuteUnit input) {
return result;
}

/**
* Clear data.
*
* @throws SQLException sql exception
*/
@Override
public void clear() throws SQLException {
clearStatements();
clearConnections();
super.clear();
batchCount = 0;
connections.clear();
routeUnits.clear();
resultSets.clear();
executeGroups.clear();
}

private void clearStatements() throws SQLException {
for (Statement each : getStatements()) {
each.close();
}
}

private void clearConnections() {
for (Connection each : connections) {
connection.release(each);
}
}
}

Expand Down
Loading

0 comments on commit d2d2598

Please sign in to comment.