Skip to content

Commit

Permalink
Support distributed transactions across multiple logical database(#19… (
Browse files Browse the repository at this point in the history
#20114)

* Support distributed transactions across multiple logical database(#19894)

* Fix test case

* Generate data source name

* Add test case

* Remove final

* JDBC does not support operations across multiple logical databases in transaction

* Fix equals usage
  • Loading branch information
FlyingZC authored Aug 16, 2022
1 parent 3bc99d5 commit e16dff4
Show file tree
Hide file tree
Showing 15 changed files with 71 additions and 32 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
import org.apache.shardingsphere.infra.binder.statement.SQLStatementContext;
import org.apache.shardingsphere.infra.binder.statement.dml.InsertStatementContext;
import org.apache.shardingsphere.infra.binder.statement.dml.SelectStatementContext;
import org.apache.shardingsphere.infra.binder.type.TableAvailable;
import org.apache.shardingsphere.infra.config.props.ConfigurationProperties;
import org.apache.shardingsphere.infra.config.props.ConfigurationPropertyKey;
import org.apache.shardingsphere.infra.context.kernel.KernelProcessor;
Expand Down Expand Up @@ -78,6 +79,7 @@
import org.apache.shardingsphere.sql.parser.sql.common.statement.dal.DALStatement;
import org.apache.shardingsphere.traffic.engine.TrafficEngine;
import org.apache.shardingsphere.traffic.rule.TrafficRule;
import org.apache.shardingsphere.transaction.TransactionHolder;

import java.sql.Connection;
import java.sql.ResultSet;
Expand Down Expand Up @@ -156,6 +158,7 @@ public ResultSet executeQuery(final String sql) throws SQLException {
ResultSet result;
try {
LogicSQL logicSQL = createLogicSQL(sql);
checkSameDatabaseNameInTransaction(logicSQL.getSqlStatementContext(), connection.getDatabaseName());
trafficInstanceId = getInstanceIdAndSet(logicSQL).orElse(null);
if (null != trafficInstanceId) {
JDBCExecutionUnit executionUnit = createTrafficExecutionUnit(trafficInstanceId, logicSQL);
Expand Down Expand Up @@ -235,6 +238,7 @@ private DriverExecutionPrepareEngine<JDBCExecutionUnit, Connection> createDriver
public int executeUpdate(final String sql) throws SQLException {
try {
LogicSQL logicSQL = createLogicSQL(sql);
checkSameDatabaseNameInTransaction(logicSQL.getSqlStatementContext(), connection.getDatabaseName());
trafficInstanceId = getInstanceIdAndSet(logicSQL).orElse(null);
if (null != trafficInstanceId) {
JDBCExecutionUnit executionUnit = createTrafficExecutionUnit(trafficInstanceId, logicSQL);
Expand Down Expand Up @@ -263,6 +267,7 @@ public int executeUpdate(final String sql, final int autoGeneratedKeys) throws S
}
try {
LogicSQL logicSQL = createLogicSQL(sql);
checkSameDatabaseNameInTransaction(logicSQL.getSqlStatementContext(), connection.getDatabaseName());
trafficInstanceId = getInstanceIdAndSet(logicSQL).orElse(null);
if (null != trafficInstanceId) {
JDBCExecutionUnit executionUnit = createTrafficExecutionUnit(trafficInstanceId, logicSQL);
Expand All @@ -289,6 +294,7 @@ public int executeUpdate(final String sql, final int[] columnIndexes) throws SQL
returnGeneratedKeys = true;
try {
LogicSQL logicSQL = createLogicSQL(sql);
checkSameDatabaseNameInTransaction(logicSQL.getSqlStatementContext(), connection.getDatabaseName());
trafficInstanceId = getInstanceIdAndSet(logicSQL).orElse(null);
if (null != trafficInstanceId) {
JDBCExecutionUnit executionUnit = createTrafficExecutionUnit(trafficInstanceId, logicSQL);
Expand All @@ -315,6 +321,7 @@ public int executeUpdate(final String sql, final String[] columnNames) throws SQ
returnGeneratedKeys = true;
try {
LogicSQL logicSQL = createLogicSQL(sql);
checkSameDatabaseNameInTransaction(logicSQL.getSqlStatementContext(), connection.getDatabaseName());
trafficInstanceId = getInstanceIdAndSet(logicSQL).orElse(null);
if (null != trafficInstanceId) {
JDBCExecutionUnit executionUnit = createTrafficExecutionUnit(trafficInstanceId, logicSQL);
Expand Down Expand Up @@ -431,6 +438,7 @@ protected Optional<Boolean> getSaneResult(final SQLStatement sqlStatement, final
private boolean execute0(final String sql, final ExecuteCallback callback) throws SQLException {
try {
LogicSQL logicSQL = createLogicSQL(sql);
checkSameDatabaseNameInTransaction(logicSQL.getSqlStatementContext(), connection.getDatabaseName());
trafficInstanceId = getInstanceIdAndSet(logicSQL).orElse(null);
if (null != trafficInstanceId) {
JDBCExecutionUnit executionUnit = createTrafficExecutionUnit(trafficInstanceId, logicSQL);
Expand All @@ -455,6 +463,19 @@ private boolean execute0(final String sql, final ExecuteCallback callback) throw
}
}

private void checkSameDatabaseNameInTransaction(final SQLStatementContext<?> sqlStatementContext, final String connectionDatabaseName) {
if (!TransactionHolder.isTransaction()) {
return;
}
if (sqlStatementContext instanceof TableAvailable) {
((TableAvailable) sqlStatementContext).getTablesContext().getDatabaseName().ifPresent(databaseName -> {
if (!databaseName.equals(connectionDatabaseName)) {
throw new ShardingSphereException("JDBC does not support operations across multiple logical databases in transaction.");
}
});
}
}

private JDBCExecutionUnit createTrafficExecutionUnit(final String trafficInstanceId, final LogicSQL logicSQL) throws SQLException {
DriverExecutionPrepareEngine<JDBCExecutionUnit, Connection> prepareEngine = createDriverExecutionPrepareEngine();
ExecutionUnit executionUnit = new ExecutionUnit(trafficInstanceId, new SQLUnit(logicSQL.getSql(), logicSQL.getParameters()));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ public final class ConnectionTransaction {

private final TransactionType transactionType;

private final String databaseName;

@Setter
@Getter
private volatile boolean rollbackOnly;
Expand All @@ -46,6 +48,7 @@ public ConnectionTransaction(final String databaseName, final TransactionRule ru
}

public ConnectionTransaction(final String databaseName, final TransactionType transactionType, final TransactionRule rule) {
this.databaseName = databaseName;
this.transactionType = transactionType;
transactionManager = rule.getResource().getTransactionManager(transactionType);
TransactionTypeHolder.set(transactionType);
Expand Down Expand Up @@ -87,7 +90,7 @@ public boolean isHoldTransaction(final boolean autoCommit) {
* @throws SQLException SQL exception
*/
public Optional<Connection> getConnection(final String dataSourceName) throws SQLException {
return isInTransaction() ? Optional.of(transactionManager.getConnection(dataSourceName)) : Optional.empty();
return isInTransaction() ? Optional.of(transactionManager.getConnection(this.databaseName, dataSourceName)) : Optional.empty();
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

package org.apache.shardingsphere.transaction.core;

import com.google.common.base.Preconditions;
import lombok.Getter;

import javax.sql.DataSource;
Expand All @@ -34,8 +35,10 @@ public final class ResourceDataSource {
private final DataSource dataSource;

public ResourceDataSource(final String originalName, final DataSource dataSource) {
String[] databaseAndDataSourceName = originalName.split("\\.");
Preconditions.checkState(2 == databaseAndDataSourceName.length, String.format("Database and data source name must be provided,`%s`.", originalName));
this.originalName = originalName;
this.dataSource = dataSource;
uniqueResourceName = ResourceIDGenerator.getInstance().nextId() + originalName;
this.uniqueResourceName = ResourceIDGenerator.getInstance().nextId() + databaseAndDataSourceName[1];
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,10 @@ private synchronized ShardingSphereTransactionManagerEngine createTransactionMan
Map<String, DataSource> dataSourceMap = new HashMap<>(databases.size());
Set<DatabaseType> databaseTypes = new HashSet<>();
for (Entry<String, ShardingSphereDatabase> entry : databases.entrySet()) {
dataSourceMap.putAll(entry.getValue().getResource().getDataSources());
ShardingSphereDatabase database = entry.getValue();
database.getResource().getDataSources().forEach((key, value) -> {
dataSourceMap.put(database.getName() + "." + key, value);
});
if (null != entry.getValue().getResource().getDatabaseType()) {
databaseTypes.add(entry.getValue().getResource().getDatabaseType());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,12 @@ public interface ShardingSphereTransactionManager extends AutoCloseable {
/**
* Get transactional connection.
*
* @param databaseName database name
* @param dataSourceName data source name
* @return connection
* @throws SQLException SQL exception
*/
Connection getConnection(String dataSourceName) throws SQLException;
Connection getConnection(String databaseName, String dataSourceName) throws SQLException;

/**
* Begin transaction.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,22 @@

public final class ResourceDataSourceTest {

private static final String DATABASE_NAME = "sharding_db";

private static final String DATA_SOURCE_NAME = "fooDataSource";

@Test
public void assertNewInstance() {
ResourceDataSource actual = new ResourceDataSource("fooDataSource", new MockedDataSource());
assertThat(actual.getOriginalName(), is("fooDataSource"));
String originalName = DATABASE_NAME + "." + DATA_SOURCE_NAME;
ResourceDataSource actual = new ResourceDataSource(originalName, new MockedDataSource());
assertThat(actual.getOriginalName(), is(originalName));
assertThat(actual.getDataSource(), instanceOf(MockedDataSource.class));
assertTrue(actual.getUniqueResourceName().startsWith("resource"));
assertTrue(actual.getUniqueResourceName().endsWith("fooDataSource"));
assertTrue(actual.getUniqueResourceName().endsWith(DATA_SOURCE_NAME));
}

@Test(expected = IllegalStateException.class)
public void assertDataSourceNameOnlyFailure() {
new ResourceDataSource(DATA_SOURCE_NAME, new MockedDataSource());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ public boolean isInTransaction() {
}

@Override
public Connection getConnection(final String dataSourceName) {
public Connection getConnection(final String databaseName, final String dataSourceName) {
return null;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ public boolean isInTransaction() {
}

@Override
public Connection getConnection(final String dataSourceName) {
public Connection getConnection(final String databaseName, final String dataSourceName) {
return null;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,9 +90,9 @@ public boolean isInTransaction() {
}

@Override
public Connection getConnection(final String dataSourceName) throws SQLException {
public Connection getConnection(final String databaseName, final String dataSourceName) throws SQLException {
Preconditions.checkState(enableSeataAT, "sharding seata-at transaction has been disabled.");
return dataSourceMap.get(dataSourceName).getConnection();
return dataSourceMap.get(databaseName + "." + dataSourceName).getConnection();
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ public final class SeataATShardingSphereTransactionManagerTest {

private static final MockSeataServer MOCK_SEATA_SERVER = new MockSeataServer();

private static final String DATA_SOURCE_UNIQUE_NAME = "sharding_db.foo_ds";

private final SeataATShardingSphereTransactionManager seataTransactionManager = new SeataATShardingSphereTransactionManager();

private final Queue<Object> requestQueue = MOCK_SEATA_SERVER.getMessageHandler().getRequestQueue();
Expand All @@ -84,7 +86,7 @@ public static void after() {

@Before
public void setUp() {
seataTransactionManager.init(DatabaseTypeFactory.getInstance("MySQL"), Collections.singletonList(new ResourceDataSource("foo_ds", new MockedDataSource())), "Seata");
seataTransactionManager.init(DatabaseTypeFactory.getInstance("MySQL"), Collections.singletonList(new ResourceDataSource(DATA_SOURCE_UNIQUE_NAME, new MockedDataSource())), "Seata");
}

@After
Expand All @@ -102,13 +104,13 @@ public void tearDown() {
public void assertInit() {
Map<String, DataSource> actual = getDataSourceMap();
assertThat(actual.size(), is(1));
assertThat(actual.get("foo_ds"), instanceOf(DataSourceProxy.class));
assertThat(actual.get(DATA_SOURCE_UNIQUE_NAME), instanceOf(DataSourceProxy.class));
assertThat(seataTransactionManager.getTransactionType(), is(TransactionType.BASE));
}

@Test
public void assertGetConnection() throws SQLException {
Connection actual = seataTransactionManager.getConnection("foo_ds");
Connection actual = seataTransactionManager.getConnection("sharding_db", "foo_ds");
assertThat(actual, instanceOf(ConnectionProxy.class));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,9 @@ public boolean isInTransaction() {
}

@Override
public Connection getConnection(final String dataSourceName) throws SQLException {
public Connection getConnection(final String databaseName, final String dataSourceName) throws SQLException {
try {
return cachedDataSources.get(dataSourceName).getConnection();
return cachedDataSources.get(databaseName + "." + dataSourceName).getConnection();
} catch (final SystemException | RollbackException ex) {
throw new SQLException(ex);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,9 +82,9 @@ public void assertIsInTransaction() {
@Test
public void assertGetConnection() throws SQLException {
xaTransactionManager.begin();
Connection actual1 = xaTransactionManager.getConnection("ds1");
Connection actual2 = xaTransactionManager.getConnection("ds2");
Connection actual3 = xaTransactionManager.getConnection("ds3");
Connection actual1 = xaTransactionManager.getConnection("demo_ds_1", "ds1");
Connection actual2 = xaTransactionManager.getConnection("demo_ds_2", "ds2");
Connection actual3 = xaTransactionManager.getConnection("demo_ds_3", "ds3");
assertThat(actual1, instanceOf(Connection.class));
assertThat(actual2, instanceOf(Connection.class));
assertThat(actual3, instanceOf(Connection.class));
Expand All @@ -93,10 +93,10 @@ public void assertGetConnection() throws SQLException {

@Test
public void assertGetConnectionOfNestedTransaction() throws SQLException {
ThreadLocal<Map<Transaction, Connection>> transactions = getEnlistedTransactions(getCachedDataSources().get("ds1"));
ThreadLocal<Map<Transaction, Connection>> transactions = getEnlistedTransactions(getCachedDataSources().get("demo_ds_1.ds1"));
xaTransactionManager.begin();
assertTrue(transactions.get().isEmpty());
xaTransactionManager.getConnection("ds1");
xaTransactionManager.getConnection("demo_ds_1", "ds1");
assertThat(transactions.get().size(), is(1));
executeNestedTransaction(transactions);
assertThat(transactions.get().size(), is(1));
Expand All @@ -106,7 +106,7 @@ public void assertGetConnectionOfNestedTransaction() throws SQLException {

private void executeNestedTransaction(final ThreadLocal<Map<Transaction, Connection>> transactions) throws SQLException {
xaTransactionManager.begin();
xaTransactionManager.getConnection("ds1");
xaTransactionManager.getConnection("demo_ds_1", "ds1");
assertThat(transactions.get().size(), is(2));
xaTransactionManager.commit(false);
assertThat(transactions.get().size(), is(1));
Expand Down Expand Up @@ -153,9 +153,9 @@ private ThreadLocal<Map<Transaction, Connection>> getEnlistedTransactions(final

private Collection<ResourceDataSource> createResourceDataSources(final DatabaseType databaseType) {
List<ResourceDataSource> result = new LinkedList<>();
result.add(new ResourceDataSource("ds1", DataSourceUtils.build(HikariDataSource.class, databaseType, "demo_ds_1")));
result.add(new ResourceDataSource("ds2", DataSourceUtils.build(HikariDataSource.class, databaseType, "demo_ds_2")));
result.add(new ResourceDataSource("ds3", DataSourceUtils.build(AtomikosDataSourceBean.class, databaseType, "demo_ds_3")));
result.add(new ResourceDataSource("demo_ds_1.ds1", DataSourceUtils.build(HikariDataSource.class, databaseType, "demo_ds_1")));
result.add(new ResourceDataSource("demo_ds_2.ds2", DataSourceUtils.build(HikariDataSource.class, databaseType, "demo_ds_2")));
result.add(new ResourceDataSource("demo_ds_3.ds3", DataSourceUtils.build(AtomikosDataSourceBean.class, databaseType, "demo_ds_3")));
return result;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ private List<Connection> createConnections(final String databaseName, final Stri
private Connection createConnection(final String databaseName, final String dataSourceName, final DataSource dataSource, final TransactionType transactionType) throws SQLException {
TransactionRule transactionRule = ProxyContext.getInstance().getContextManager().getMetaDataContexts().getMetaData().getGlobalRuleMetaData().getSingleRule(TransactionRule.class);
ShardingSphereTransactionManager transactionManager = transactionRule.getResource().getTransactionManager(transactionType);
Connection result = isInTransaction(transactionManager) ? transactionManager.getConnection(dataSourceName) : dataSource.getConnection();
Connection result = isInTransaction(transactionManager) ? transactionManager.getConnection(databaseName, dataSourceName) : dataSource.getConnection();
if (dataSourceName.contains(".")) {
String catalog = dataSourceName.split("\\.")[1];
result.setCatalog(catalog);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
import lombok.Setter;
import org.apache.shardingsphere.infra.config.props.ConfigurationPropertyKey;
import org.apache.shardingsphere.infra.database.type.DatabaseType;
import org.apache.shardingsphere.infra.exception.ShardingSphereException;
import org.apache.shardingsphere.infra.executor.sql.prepare.driver.ExecutorStatementManager;
import org.apache.shardingsphere.infra.metadata.user.Grantee;
import org.apache.shardingsphere.infra.session.ConnectionContext;
Expand Down Expand Up @@ -104,9 +103,6 @@ public void setCurrentDatabase(final String databaseName) {
if (null != databaseName && databaseName.equals(this.databaseName)) {
return;
}
if (transactionStatus.isInTransaction()) {
throw new ShardingSphereException("Failed to switch database, please terminate current transaction.");
}
if (statementManager instanceof JDBCBackendStatement) {
((JDBCBackendStatement) statementManager).setDatabaseName(databaseName);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,8 @@ public void assertFailedSwitchTransactionTypeWhileBegin() throws SQLException {
connectionSession.getTransactionStatus().setTransactionType(TransactionType.XA);
}

@Test(expected = ShardingSphereException.class)
public void assertFailedSwitchSchemaWhileBegin() throws SQLException {
@Test
public void assertSwitchSchemaWhileBegin() throws SQLException {
connectionSession.setCurrentDatabase("db");
JDBCBackendTransactionManager transactionManager = new JDBCBackendTransactionManager(backendConnection);
transactionManager.begin();
Expand Down

0 comments on commit e16dff4

Please sign in to comment.