diff --git a/v2/spanner-common/pom.xml b/v2/spanner-common/pom.xml index 78327e7067..d62c98c91f 100644 --- a/v2/spanner-common/pom.xml +++ b/v2/spanner-common/pom.xml @@ -44,10 +44,17 @@ 1.0-SNAPSHOT compile + + com.datastax.oss + java-driver-core + 4.17.0 + compile + - com.datastax.oss - java-driver-core - 4.17.0 + org.mockito + mockito-inline + 3.12.4 + test diff --git a/v2/spanner-common/src/main/java/com/google/cloud/teleport/v2/spanner/migrations/metadata/CassandraSourceMetadata.java b/v2/spanner-common/src/main/java/com/google/cloud/teleport/v2/spanner/migrations/metadata/CassandraSourceMetadata.java index 9634e5b082..cdc4ea857e 100644 --- a/v2/spanner-common/src/main/java/com/google/cloud/teleport/v2/spanner/migrations/metadata/CassandraSourceMetadata.java +++ b/v2/spanner-common/src/main/java/com/google/cloud/teleport/v2/spanner/migrations/metadata/CassandraSourceMetadata.java @@ -17,7 +17,11 @@ import autovalue.shaded.com.google.common.collect.ImmutableList; import com.datastax.oss.driver.api.core.cql.ResultSet; -import com.google.cloud.teleport.v2.spanner.migrations.schema.*; +import com.google.cloud.teleport.v2.spanner.migrations.schema.ColumnPK; +import com.google.cloud.teleport.v2.spanner.migrations.schema.NameAndCols; +import com.google.cloud.teleport.v2.spanner.migrations.schema.Schema; +import com.google.cloud.teleport.v2.spanner.migrations.schema.SourceColumnDefinition; +import com.google.cloud.teleport.v2.spanner.migrations.schema.SourceColumnType; import com.google.cloud.teleport.v2.spanner.migrations.schema.cassandra.SourceColumn; import com.google.cloud.teleport.v2.spanner.migrations.schema.cassandra.SourceSchema; import com.google.cloud.teleport.v2.spanner.migrations.schema.cassandra.SourceTable; diff --git a/v2/spanner-common/src/main/java/com/google/cloud/teleport/v2/spanner/migrations/shard/CassandraShard.java b/v2/spanner-common/src/main/java/com/google/cloud/teleport/v2/spanner/migrations/shard/CassandraShard.java index 0ad6b09c7c..f0a608f556 100644 --- a/v2/spanner-common/src/main/java/com/google/cloud/teleport/v2/spanner/migrations/shard/CassandraShard.java +++ b/v2/spanner-common/src/main/java/com/google/cloud/teleport/v2/spanner/migrations/shard/CassandraShard.java @@ -15,124 +15,103 @@ */ package com.google.cloud.teleport.v2.spanner.migrations.shard; +import com.datastax.oss.driver.api.core.config.DefaultDriverOption; +import com.datastax.oss.driver.api.core.config.DriverConfigLoader; +import com.datastax.oss.driver.api.core.config.DriverExecutionProfile; +import java.util.List; import java.util.Objects; public class CassandraShard extends Shard { - private String keyspace; - private String consistencyLevel = "LOCAL_QUORUM"; - private boolean sslOptions = false; - private String protocolVersion = "v5"; - private String dataCenter = "datacenter1"; - private int localPoolSize = 1024; - private int remotePoolSize = 256; - - public CassandraShard( - String logicalShardId, - String host, - String port, - String user, - String password, - String keyspace, - String consistencyLevel, - Boolean sslOptions, - String protocolVersion, - String dataCenter, - Integer localPoolSize, - Integer remotePoolSize) { - super(logicalShardId, host, port, user, password, null, null, null, null); - this.keyspace = keyspace; - this.consistencyLevel = consistencyLevel; - this.sslOptions = sslOptions; - this.protocolVersion = protocolVersion; - this.dataCenter = dataCenter; - this.localPoolSize = localPoolSize; - this.remotePoolSize = remotePoolSize; - } - - // Getters - public String getKeySpaceName() { - return keyspace; - } + private final DriverConfigLoader configLoader; - public String getConsistencyLevel() { - return consistencyLevel; + public CassandraShard(DriverConfigLoader configLoader) { + super(null, null, null, null, null, null, null, null, null); + this.configLoader = configLoader; + validateFields(); + extractAndSetHostAndPort(); } - public boolean getSSLOptions() { - return sslOptions; + private void validateFields() { + if (getContactPoints() == null || getContactPoints().isEmpty()) { + throw new IllegalArgumentException("CONTACT_POINTS cannot be null or empty."); + } + if (getKeySpaceName() == null || getKeySpaceName().isEmpty()) { + throw new IllegalArgumentException("SESSION_KEYSPACE cannot be null or empty."); + } } - public String getProtocolVersion() { - return protocolVersion; - } + private void extractAndSetHostAndPort() { + String firstContactPoint = getContactPoints().get(0); + String[] parts = firstContactPoint.split(":"); - public String getDataCenter() { - return dataCenter; - } + if (parts.length < 2) { + throw new IllegalArgumentException("Invalid contact point format: " + firstContactPoint); + } - public int getLocalPoolSize() { - return localPoolSize; + String host = parts[0]; + String port = parts[1]; + + setHost(host); + setPort(port); } - public int getRemotePoolSize() { - return remotePoolSize; + private DriverExecutionProfile getProfile() { + return configLoader.getInitialConfig().getDefaultProfile(); } - // Setters - public void setKeySpaceName(String keySpaceName) { - this.keyspace = keySpaceName; + // Getters that fetch data from DriverConfigLoader + public List getContactPoints() { + return getProfile().getStringList(DefaultDriverOption.CONTACT_POINTS); } - public void setConsistencyLevel(String consistencyLevel) { - this.consistencyLevel = consistencyLevel; + public String getKeySpaceName() { + return getProfile().getString(DefaultDriverOption.SESSION_KEYSPACE); } - public void setSslOptions(boolean sslOptions) { - this.sslOptions = sslOptions; + public String getConsistencyLevel() { + return getProfile().getString(DefaultDriverOption.REQUEST_CONSISTENCY, "LOCAL_QUORUM"); } - public void setProtocolVersion(String protocolVersion) { - this.protocolVersion = protocolVersion; + public boolean isSslEnabled() { + return getProfile().getBoolean(DefaultDriverOption.SSL_ENGINE_FACTORY_CLASS, false); } - public void setDataCenter(String dataCenter) { - this.dataCenter = dataCenter; + public String getProtocolVersion() { + return getProfile().getString(DefaultDriverOption.PROTOCOL_VERSION, "V5"); } - public void setLocalPoolSize(int localPoolSize) { - this.localPoolSize = localPoolSize; + public String getDataCenter() { + return getProfile() + .getString(DefaultDriverOption.LOAD_BALANCING_LOCAL_DATACENTER, "datacenter1"); } - public void setRemotePoolSize(int remotePoolSize) { - this.remotePoolSize = remotePoolSize; + public int getLocalPoolSize() { + return getProfile().getInt(DefaultDriverOption.CONNECTION_POOL_LOCAL_SIZE, 1024); } - public void validate() { - validateField(getHost(), "Host"); - validateField(getPort(), "Port"); - validateField(getUserName(), "Username"); - validateField(getPassword(), "Password"); - validateField(getKeySpaceName(), "Keyspace"); + public int getRemotePoolSize() { + return getProfile().getInt(DefaultDriverOption.CONNECTION_POOL_REMOTE_SIZE, 256); } - private void validateField(String value, String fieldName) { - if (value == null || value.isEmpty()) { - throw new IllegalArgumentException(fieldName + " is required"); - } + public DriverConfigLoader getConfigLoader() { + return configLoader; } @Override public String toString() { return String.format( - "CassandraShard{logicalShardId='%s', host='%s', port='%s', user='%s', keySpaceName='%s', datacenter='%s', consistencyLevel='%s', protocolVersion='%s'}", + "CassandraShard{logicalShardId='%s', contactPoints=%s, keyspace='%s', consistencyLevel='%s', sslOptions=%b, protocolVersion='%s', dataCenter='%s', localPoolSize=%d, remotePoolSize=%d, host='%s', port='%s'}", getLogicalShardId(), - getHost(), - getPort(), - getUserName(), + getContactPoints(), getKeySpaceName(), - getDataCenter(), getConsistencyLevel(), - getProtocolVersion()); + isSslEnabled(), + getProtocolVersion(), + getDataCenter(), + getLocalPoolSize(), + getRemotePoolSize(), + getHost(), + getPort()); } @Override @@ -140,34 +119,26 @@ public boolean equals(Object o) { if (this == o) return true; if (!(o instanceof CassandraShard)) return false; CassandraShard that = (CassandraShard) o; - return sslOptions == that.sslOptions - && localPoolSize == that.localPoolSize - && remotePoolSize == that.remotePoolSize - && Objects.equals(getLogicalShardId(), that.getLogicalShardId()) - && Objects.equals(getHost(), that.getHost()) - && Objects.equals(getPort(), that.getPort()) - && Objects.equals(getUserName(), that.getUserName()) - && Objects.equals(getPassword(), that.getPassword()) - && Objects.equals(keyspace, that.keyspace) - && Objects.equals(dataCenter, that.dataCenter) - && Objects.equals(consistencyLevel, that.consistencyLevel) - && Objects.equals(protocolVersion, that.protocolVersion); + return Objects.equals(getContactPoints(), that.getContactPoints()) + && Objects.equals(getKeySpaceName(), that.getKeySpaceName()) + && Objects.equals(getConsistencyLevel(), that.getConsistencyLevel()) + && Objects.equals(isSslEnabled(), that.isSslEnabled()) + && Objects.equals(getProtocolVersion(), that.getProtocolVersion()) + && Objects.equals(getDataCenter(), that.getDataCenter()) + && Objects.equals(getLocalPoolSize(), that.getLocalPoolSize()) + && Objects.equals(getRemotePoolSize(), that.getRemotePoolSize()); } @Override public int hashCode() { return Objects.hash( - getLogicalShardId(), - getHost(), - getPort(), - getUserName(), - getPassword(), - keyspace, - dataCenter, - consistencyLevel, - protocolVersion, - sslOptions, - localPoolSize, - remotePoolSize); + getContactPoints(), + getKeySpaceName(), + getConsistencyLevel(), + isSslEnabled(), + getProtocolVersion(), + getDataCenter(), + getLocalPoolSize(), + getRemotePoolSize()); } } diff --git a/v2/spanner-common/src/main/java/com/google/cloud/teleport/v2/spanner/migrations/utils/CassandraConfigFileReader.java b/v2/spanner-common/src/main/java/com/google/cloud/teleport/v2/spanner/migrations/utils/CassandraConfigFileReader.java index 0e05cc7a57..2f9935fd1e 100644 --- a/v2/spanner-common/src/main/java/com/google/cloud/teleport/v2/spanner/migrations/utils/CassandraConfigFileReader.java +++ b/v2/spanner-common/src/main/java/com/google/cloud/teleport/v2/spanner/migrations/utils/CassandraConfigFileReader.java @@ -15,19 +15,12 @@ */ package com.google.cloud.teleport.v2.spanner.migrations.utils; +import com.datastax.oss.driver.api.core.config.DriverConfigLoader; import com.google.cloud.teleport.v2.spanner.migrations.shard.CassandraShard; import com.google.cloud.teleport.v2.spanner.migrations.shard.Shard; -import com.google.gson.FieldNamingPolicy; -import com.google.gson.Gson; -import com.google.gson.GsonBuilder; import java.io.IOException; -import java.io.InputStream; -import java.nio.channels.Channels; -import java.nio.charset.StandardCharsets; import java.util.Collections; import java.util.List; -import org.apache.beam.sdk.io.FileSystems; -import org.apache.commons.io.IOUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -38,8 +31,6 @@ public class CassandraConfigFileReader { private static final Logger LOG = LoggerFactory.getLogger(CassandraConfigFileReader.class); - private static final Gson GSON = - new GsonBuilder().setFieldNamingPolicy(FieldNamingPolicy.IDENTITY).create(); /** * Reads the Cassandra configuration file from the specified GCS path and converts it into a list @@ -49,28 +40,20 @@ public class CassandraConfigFileReader { * @return a list containing the parsed CassandraShard. */ public List getCassandraShard(String cassandraConfigFilePath) { - try (InputStream stream = getFileInputStream(cassandraConfigFilePath)) { - String configContent = IOUtils.toString(stream, StandardCharsets.UTF_8); - CassandraShard shard = GSON.fromJson(configContent, CassandraShard.class); - - LOG.info("Successfully read Cassandra config: {}", shard); + try { + LOG.info("Reading Cassandra configuration from: {}", cassandraConfigFilePath); + DriverConfigLoader configLoader = + CassandraDriverConfigLoader.loadFile(cassandraConfigFilePath); + CassandraShard shard = new CassandraShard(configLoader); + LOG.info("Successfully created CassandraShard: {}", shard); return Collections.singletonList(shard); } catch (IOException e) { String errorMessage = - "Failed to read Cassandra config file. Ensure it is ASCII or UTF-8 encoded and contains a well-formed JSON string."; + String.format( + "Failed to read Cassandra config file from path: %s. Ensure it exists, is accessible, and is properly formatted.", + cassandraConfigFilePath); LOG.error(errorMessage, e); throw new RuntimeException(errorMessage, e); } } - - /** - * Retrieves an InputStream for the specified GCS file path. - * - * @param filePath the GCS file path. - * @return an InputStream for the file. - * @throws IOException if the file cannot be accessed or opened. - */ - private InputStream getFileInputStream(String filePath) throws IOException { - return Channels.newInputStream(FileSystems.open(FileSystems.matchNewResource(filePath, false))); - } } diff --git a/v2/spanner-to-sourcedb/pom.xml b/v2/spanner-to-sourcedb/pom.xml index 63bd047ed6..c2ed414902 100644 --- a/v2/spanner-to-sourcedb/pom.xml +++ b/v2/spanner-to-sourcedb/pom.xml @@ -100,6 +100,12 @@ ${project.version} test + + org.junit.jupiter + junit-jupiter-api + 5.5.2 + test + diff --git a/v2/spanner-to-sourcedb/src/main/java/com/google/cloud/teleport/v2/templates/dbutils/connection/CassandraConnectionHelper.java b/v2/spanner-to-sourcedb/src/main/java/com/google/cloud/teleport/v2/templates/dbutils/connection/CassandraConnectionHelper.java index d7fe8eb556..a2f2646dce 100644 --- a/v2/spanner-to-sourcedb/src/main/java/com/google/cloud/teleport/v2/templates/dbutils/connection/CassandraConnectionHelper.java +++ b/v2/spanner-to-sourcedb/src/main/java/com/google/cloud/teleport/v2/templates/dbutils/connection/CassandraConnectionHelper.java @@ -17,13 +17,13 @@ import com.datastax.oss.driver.api.core.CqlSession; import com.datastax.oss.driver.api.core.CqlSessionBuilder; -import com.datastax.oss.driver.api.core.config.DefaultDriverOption; import com.datastax.oss.driver.api.core.config.DriverConfigLoader; -import com.datastax.oss.driver.api.core.config.ProgrammaticDriverConfigLoaderBuilder; import com.google.cloud.teleport.v2.spanner.migrations.shard.CassandraShard; import com.google.cloud.teleport.v2.spanner.migrations.shard.Shard; +import com.google.cloud.teleport.v2.spanner.migrations.utils.CassandraDriverConfigLoader; import com.google.cloud.teleport.v2.templates.exceptions.ConnectionException; import com.google.cloud.teleport.v2.templates.models.ConnectionHelperRequest; +import java.io.FileNotFoundException; import java.net.InetSocketAddress; import java.util.List; import java.util.Map; @@ -83,7 +83,6 @@ public synchronized void init(ConnectionHelperRequest connectionHelperRequest) { CassandraShard cassandraShard = (CassandraShard) shard; try { - cassandraShard.validate(); CqlSession session = createCqlSession(cassandraShard); String connectionKey = generateConnectionKey(cassandraShard); connectionPoolMap.put(connectionKey, session); @@ -137,15 +136,21 @@ public boolean isConnectionPoolInitialized() { * @return A {@link CqlSession} instance. */ private CqlSession createCqlSession(CassandraShard cassandraShard) { - CqlSessionBuilder builder = - CqlSession.builder() - .addContactPoint( - new InetSocketAddress( - cassandraShard.getHost(), Integer.parseInt(cassandraShard.getPort()))) - .withAuthCredentials(cassandraShard.getUserName(), cassandraShard.getPassword()) - .withKeyspace(cassandraShard.getKeySpaceName()); - - DriverConfigLoader configLoader = createConfigLoader(cassandraShard); + CqlSessionBuilder builder = CqlSession.builder(); + + for (String contactPoint : cassandraShard.getContactPoints()) { + String[] parts = contactPoint.split(":"); + String host = parts[0]; + int port = Integer.parseInt(parts[1]); + builder.addContactPoint(new InetSocketAddress(host, port)); + } + + builder + .withAuthCredentials(cassandraShard.getUserName(), cassandraShard.getPassword()) + .withKeyspace(cassandraShard.getKeySpaceName()); + + DriverConfigLoader configLoader = cassandraShard.getConfigLoader(); + configLoader.getInitialConfig(); builder.withConfigLoader(configLoader); return builder.build(); @@ -164,21 +169,26 @@ private String generateConnectionKey(CassandraShard shard) { } /** - * Creates a driver configuration loader for the given {@link CassandraShard}. + * Loads the Cassandra driver configuration from the specified file path. * - * @param cassandraShard The shard containing configuration details. - * @return A {@link DriverConfigLoader} instance. + *

This method uses the provided `configFilePath` to load the Cassandra driver configuration + * using the {@link CassandraDriverConfigLoader}. If the configuration file is not found, an error + * is logged, and a {@link RuntimeException} is thrown. + * + * @param configFilePath The path to the Cassandra driver configuration file. This should be a + * valid path pointing to a configuration file (e.g., "gs://path/to/cassandra_config.yaml"). + * @return A {@link DriverConfigLoader} that contains the loaded Cassandra driver configuration. + * @throws RuntimeException If an error occurs while loading the configuration (e.g., if the file + * is not found). The underlying {@link FileNotFoundException} will be wrapped in a {@link + * RuntimeException}. */ - private DriverConfigLoader createConfigLoader(CassandraShard cassandraShard) { - ProgrammaticDriverConfigLoaderBuilder configLoaderBuilder = - DriverConfigLoader.programmaticBuilder(); - - configLoaderBuilder - .withInt(DefaultDriverOption.CONNECTION_POOL_LOCAL_SIZE, cassandraShard.getLocalPoolSize()) - .withInt( - DefaultDriverOption.CONNECTION_POOL_REMOTE_SIZE, cassandraShard.getRemotePoolSize()); - - return configLoaderBuilder.build(); + private DriverConfigLoader loadDriverConfig(String configFilePath) { + try { + return CassandraDriverConfigLoader.loadFile(configFilePath); + } catch (FileNotFoundException e) { + LOG.error("Could not load Cassandra driver configuration from path: {}", configFilePath, e); + throw new RuntimeException("Error loading Cassandra configuration", e); + } } /** diff --git a/v2/spanner-to-sourcedb/src/test/java/com/google/cloud/teleport/v2/templates/dbutils/connection/CassandraConnectionHelperTest.java b/v2/spanner-to-sourcedb/src/test/java/com/google/cloud/teleport/v2/templates/dbutils/connection/CassandraConnectionHelperTest.java new file mode 100644 index 0000000000..5a0ce23037 --- /dev/null +++ b/v2/spanner-to-sourcedb/src/test/java/com/google/cloud/teleport/v2/templates/dbutils/connection/CassandraConnectionHelperTest.java @@ -0,0 +1,183 @@ +/* + * Copyright (C) 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.teleport.v2.templates.dbutils.connection; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import com.datastax.oss.driver.api.core.CqlSession; +import com.google.cloud.teleport.v2.spanner.migrations.shard.CassandraShard; +import com.google.cloud.teleport.v2.spanner.migrations.shard.Shard; +import com.google.cloud.teleport.v2.templates.exceptions.ConnectionException; +import com.google.cloud.teleport.v2.templates.models.ConnectionHelperRequest; +import java.util.Arrays; +import java.util.Map; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.Mock; + +class CassandraConnectionHelperTest { + + @Mock private CassandraShard cassandraShard; + @Mock private CassandraConnectionHelper connectionHelper; + + @BeforeEach + void setUp() { + connectionHelper = new CassandraConnectionHelper(); + cassandraShard = mock(CassandraShard.class); + } + + @Test + void testInit_ShouldInitializeConnectionPool() { + when(cassandraShard.getHost()).thenReturn("localhost"); + when(cassandraShard.getPort()).thenReturn("9042"); + when(cassandraShard.getUserName()).thenReturn("user"); + when(cassandraShard.getPassword()).thenReturn("password"); + when(cassandraShard.getKeySpaceName()).thenReturn("mykeyspace"); + + ConnectionHelperRequest request = mock(ConnectionHelperRequest.class); + when(request.getShards()).thenReturn(Arrays.asList(cassandraShard)); + when(request.getMaxConnections()).thenReturn(10); + connectionHelper.init(request); + assertTrue(connectionHelper.isConnectionPoolInitialized()); + } + + @Test + void testGetConnection_ShouldReturnValidSession() throws ConnectionException { + String connectionKey = "localhost:9042/user/mykeyspace"; + CqlSession mockSession = mock(CqlSession.class); + connectionHelper.setConnectionPoolMap(Map.of(connectionKey, mockSession)); + + CqlSession session = connectionHelper.getConnection(connectionKey); + + assertNotNull(session); + assertEquals(mockSession, session); + } + + @Test + void testGetConnection_ShouldThrowException_WhenConnectionNotFound() { + assertThrows( + ConnectionException.class, + () -> { + connectionHelper.getConnection("invalidKey"); + }); + } + + @Test + void testIsConnectionPoolInitialized_ShouldReturnTrue_WhenInitialized() { + ConnectionHelperRequest request = mock(ConnectionHelperRequest.class); + when(request.getShards()).thenReturn(Arrays.asList(mock(CassandraShard.class))); + when(request.getMaxConnections()).thenReturn(10); + + connectionHelper.init(request); + + assertTrue(connectionHelper.isConnectionPoolInitialized()); + } + + @Test + void testGetConnection_ShouldThrowConnectionException_WhenPoolNotInitialized() { + assertThrows( + ConnectionException.class, + () -> { + connectionHelper.getConnection("anyKey"); + }); + } + + @Test + void testInit_ShouldHandleException_WhenCqlSessionCreationFails() { + CassandraShard invalidShard = mock(CassandraShard.class); + when(invalidShard.getHost()).thenReturn("localhost"); + when(invalidShard.getPort()).thenReturn("9042"); + when(invalidShard.getUserName()).thenReturn("invalidUser"); + when(invalidShard.getPassword()).thenReturn("invalidPassword"); + when(invalidShard.getKeySpaceName()).thenReturn("mykeyspace"); + + ConnectionHelperRequest request = mock(ConnectionHelperRequest.class); + when(request.getShards()).thenReturn(Arrays.asList(invalidShard)); + when(request.getMaxConnections()).thenReturn(10); + + connectionHelper.init(request); + assertFalse(connectionHelper.isConnectionPoolInitialized()); + } + + @Test + void testSetConnectionPoolMap_ShouldOverrideConnectionPoolMap() throws ConnectionException { + CqlSession mockSession = mock(CqlSession.class); + connectionHelper.setConnectionPoolMap(Map.of("localhost:9042/user/mykeyspace", mockSession)); + + CqlSession session = connectionHelper.getConnection("localhost:9042/user/mykeyspace"); + assertNotNull(session); + assertEquals(mockSession, session); + } + + @Test + void testGetConnectionPoolNotFound() { + connectionHelper.setConnectionPoolMap(Map.of()); + + ConnectionException exception = + assertThrows( + ConnectionException.class, + () -> { + connectionHelper.getConnection("nonexistentKey"); + }); + + assertEquals("Connection pool is not initialized.", exception.getMessage()); + } + + @Test + void testGetConnectionWhenPoolNotInitialized() { + connectionHelper.setConnectionPoolMap(null); + ConnectionException exception = + assertThrows( + ConnectionException.class, + () -> { + connectionHelper.getConnection("localhost:9042/testuser/testKeyspace"); + }); + assertEquals("Connection pool is not initialized.", exception.getMessage()); + } + + @Test + void testGetConnectionWithValidKey() throws ConnectionException { + CqlSession mockSession = mock(CqlSession.class); + + String connectionKey = "localhost:9042/testuser/testKeyspace"; + connectionHelper.setConnectionPoolMap(Map.of(connectionKey, mockSession)); + + CqlSession session = connectionHelper.getConnection(connectionKey); + + assertEquals(mockSession, session, "The returned connection should match the mock session."); + } + + @Test + void testInit_ShouldThrowIllegalArgumentException_WhenInvalidShardTypeIsProvideds() { + Shard invalidShard = mock(Shard.class); + CassandraConnectionHelper connectionHelper = new CassandraConnectionHelper(); + ConnectionHelperRequest request = mock(ConnectionHelperRequest.class); + when(request.getShards()).thenReturn(java.util.Collections.singletonList(invalidShard)); + IllegalArgumentException exception = + assertThrows( + IllegalArgumentException.class, + () -> { + connectionHelper.init(request); + }); + assertEquals("Invalid shard object", exception.getMessage()); + } +} diff --git a/v2/spanner-to-sourcedb/src/test/java/com/google/cloud/teleport/v2/templates/dbutils/processor/SourceProcessorFactoryTest.java b/v2/spanner-to-sourcedb/src/test/java/com/google/cloud/teleport/v2/templates/dbutils/processor/SourceProcessorFactoryTest.java index 3b9c0e64bf..b595b4f7ca 100644 --- a/v2/spanner-to-sourcedb/src/test/java/com/google/cloud/teleport/v2/templates/dbutils/processor/SourceProcessorFactoryTest.java +++ b/v2/spanner-to-sourcedb/src/test/java/com/google/cloud/teleport/v2/templates/dbutils/processor/SourceProcessorFactoryTest.java @@ -18,9 +18,12 @@ import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.doNothing; +import com.google.cloud.teleport.v2.spanner.migrations.shard.CassandraShard; import com.google.cloud.teleport.v2.spanner.migrations.shard.Shard; import com.google.cloud.teleport.v2.templates.constants.Constants; +import com.google.cloud.teleport.v2.templates.dbutils.connection.CassandraConnectionHelper; import com.google.cloud.teleport.v2.templates.dbutils.connection.JdbcConnectionHelper; +import com.google.cloud.teleport.v2.templates.dbutils.dao.source.CassandraDao; import com.google.cloud.teleport.v2.templates.dbutils.dao.source.JdbcDao; import com.google.cloud.teleport.v2.templates.dbutils.dml.MySQLDMLGenerator; import com.google.cloud.teleport.v2.templates.exceptions.UnsupportedSourceException; @@ -82,4 +85,32 @@ public void testCreateSourceProcessor_invalidSource() throws Exception { SourceProcessorFactory.createSourceProcessor("invalid_source", shards, maxConnections); } + + @Test + public void testCreateSourceProcessor_cassandra_validSource() throws Exception { + CassandraShard mockCassandraShard = Mockito.mock(CassandraShard.class); + Mockito.when(mockCassandraShard.getContactPoints()).thenReturn(List.of("localhost:9042")); + Mockito.when(mockCassandraShard.getKeySpaceName()).thenReturn("mydatabase"); + Mockito.when(mockCassandraShard.getLogicalShardId()).thenReturn("shard1"); + Mockito.when(mockCassandraShard.getConsistencyLevel()).thenReturn("LOCAL_QUORUM"); + Mockito.when(mockCassandraShard.getProtocolVersion()).thenReturn("v5"); + Mockito.when(mockCassandraShard.getLocalPoolSize()).thenReturn(1024); + Mockito.when(mockCassandraShard.getRemotePoolSize()).thenReturn(1024); + + List shards = List.of(mockCassandraShard); + int maxConnections = 10; + CassandraConnectionHelper mockConnectionHelper = Mockito.mock(CassandraConnectionHelper.class); + doNothing().when(mockConnectionHelper).init(any()); + SourceProcessorFactory.setConnectionHelperMap( + Map.of(Constants.SOURCE_CASSANDRA, mockConnectionHelper)); + SourceProcessor processor = + SourceProcessorFactory.createSourceProcessor( + Constants.SOURCE_CASSANDRA, shards, maxConnections); + + Assert.assertNotNull(processor); + // ToDo this Particular line will get enable in DML PR + // Assert.assertTrue(processor.getDmlGenerator() instanceof CassandraDMLGenerator); + Assert.assertEquals(1, processor.getSourceDaoMap().size()); + Assert.assertTrue(processor.getSourceDaoMap().get("shard1") instanceof CassandraDao); + } }