Cassandra RR custom transformation IT - [To check pipeline test] #117

Closed · wants to merge 13 commits
@@ -0,0 +1,63 @@
/*
* Copyright (C) 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not
* use this file except in compliance with the License. You may obtain a copy of
* the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations under
* the License.
*/
package com.custom;

import com.google.cloud.teleport.v2.spanner.exceptions.InvalidTransformationException;
import com.google.cloud.teleport.v2.spanner.utils.ISpannerMigrationTransformer;
import com.google.cloud.teleport.v2.spanner.utils.MigrationTransformationRequest;
import com.google.cloud.teleport.v2.spanner.utils.MigrationTransformationResponse;
import java.util.HashMap;
import java.util.Map;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

// TODO: Rename the class since it's being used in both Live and Reverse replication tests and in
// both ITs and LTs
public class CustomTransformationWithCassandraForLiveIT implements ISpannerMigrationTransformer {

private static final Logger LOG =
    LoggerFactory.getLogger(CustomTransformationWithCassandraForLiveIT.class);

@Override
public void init(String parameters) {
LOG.info("init called with {}", parameters);
}

@Override
public MigrationTransformationResponse toSpannerRow(MigrationTransformationRequest request)
throws InvalidTransformationException {
if (request.getTableName().equals("Customers")) {
Map<String, Object> row = new HashMap<>(request.getRequestRow());
row.put("full_name", row.get("first_name") + " " + row.get("last_name"));
row.put("migration_shard_id", request.getShardId() + "_" + row.get("id"));
MigrationTransformationResponse response = new MigrationTransformationResponse(row, false);
return response;
}
return new MigrationTransformationResponse(null, false);
}

@Override
public MigrationTransformationResponse toSourceRow(MigrationTransformationRequest request)
throws InvalidTransformationException {
if (request.getTableName().equals("customers")) {
Map<String, Object> requestRow = request.getRequestRow();
Map<String, Object> row = new HashMap<>();
row.put("full_name", requestRow.get("first_name") + " " + requestRow.get("last_name"));
MigrationTransformationResponse response = new MigrationTransformationResponse(row, false);
return response;
}
return new MigrationTransformationResponse(null, false);
}
}
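A quick sketch of the intended behavior, for reviewers: the constructor shape of MigrationTransformationRequest below is an assumption (only getTableName(), getRequestRow() and getShardId() are exercised in this diff), and the row values are illustrative.

// Illustration only (not part of this PR): expected toSpannerRow output for a "Customers" row.
Map<String, Object> sourceRow = new HashMap<>();
sourceRow.put("id", 42L);
sourceRow.put("first_name", "Ada");
sourceRow.put("last_name", "Lovelace");
// Assumed argument order: (tableName, requestRow, shardId, eventType).
MigrationTransformationRequest request =
    new MigrationTransformationRequest("Customers", sourceRow, "shard1", "INSERT");
MigrationTransformationResponse response =
    new CustomTransformationWithCassandraForLiveIT().toSpannerRow(request);
// The response row should now additionally contain:
//   full_name          -> "Ada Lovelace"
//   migration_shard_id -> "shard1_42"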
@@ -200,7 +200,7 @@ public interface Options extends PipelineOptions, StreamingOptions {
optional = true,
description = "Cloud Spanner shadow table prefix.",
helpText = "The prefix used to name shadow tables. Default: `shadow_`.")
@Default.String("shadow_")
@Default.String("rev_shadow_")
String getShadowTablePrefix();

void setShadowTablePrefix(String value);
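For reviewers, the practical effect of the new default, assuming shadow table names are formed by plain prefix concatenation (not shown in this diff):

// Illustration only: shadow table naming under the new default.
String prefix = "rev_shadow_";              // new @Default.String value above
String shadowTable = prefix + "Customers";  // -> "rev_shadow_Customers" (was "shadow_Customers")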
@@ -116,7 +116,8 @@ public DMLGeneratorResponse getDMLStatement(DMLGeneratorRequest dmlGeneratorRequest) {
sourceTable,
dmlGeneratorRequest.getNewValuesJson(),
dmlGeneratorRequest.getKeyValuesJson(),
-            dmlGeneratorRequest.getSourceDbTimezoneOffset());
+            dmlGeneratorRequest.getSourceDbTimezoneOffset(),
+            dmlGeneratorRequest.getCustomTransformationResponse());
if (pkColumnNameValues == null) {
LOG.warn(
"Failed to generate primary key values for table {}. Skipping the record.",
@@ -166,7 +167,8 @@ private static DMLGeneratorResponse generateDMLResponse(
sourceTable,
dmlGeneratorRequest.getNewValuesJson(),
dmlGeneratorRequest.getKeyValuesJson(),
-            dmlGeneratorRequest.getSourceDbTimezoneOffset());
+            dmlGeneratorRequest.getSourceDbTimezoneOffset(),
+            dmlGeneratorRequest.getCustomTransformationResponse());
Map<String, PreparedStatementValueObject<?>> allColumnNamesAndValues =
ImmutableMap.<String, PreparedStatementValueObject<?>>builder()
.putAll(pkColumnNameValues)
@@ -287,6 +289,7 @@ private static DMLGeneratorResponse getDeleteStatementCQL(
* @param newValuesJson the JSON object containing new values for columns.
* @param keyValuesJson the JSON object containing key values for columns.
* @param sourceDbTimezoneOffset the timezone offset of the source database.
+   * @param customTransformationResponse the column values produced by the custom transformation, keyed by column name.
* @return a map of column names to their corresponding prepared statement value objects.
* <p>This method: 1. Iterates over the non-primary key column definitions in the source table
* schema. 2. Maps each column in the source table schema to its corresponding column in the
@@ -299,9 +302,14 @@ private static Map<String, PreparedStatementValueObject<?>> getColumnValues(
SourceTable sourceTable,
JSONObject newValuesJson,
JSONObject keyValuesJson,
-      String sourceDbTimezoneOffset) {
+      String sourceDbTimezoneOffset,
+      Map<String, Object> customTransformationResponse) {
Map<String, PreparedStatementValueObject<?>> response = new HashMap<>();
Set<String> sourcePKs = sourceTable.getPrimaryKeySet();
+    Set<String> customTransformColumns = null;
+    if (customTransformationResponse != null) {
+      customTransformColumns = customTransformationResponse.keySet();
+    }
for (Map.Entry<String, SourceColumnDefinition> entry : sourceTable.getColDefs().entrySet()) {
SourceColumnDefinition sourceColDef = entry.getValue();

@@ -317,7 +325,14 @@ private static Map<String, PreparedStatementValueObject<?>> getColumnValues(
}
String spannerColumnName = spannerColDef.getName();
PreparedStatementValueObject<?> columnValue;
-      if (keyValuesJson.has(spannerColumnName)) {
+      if (customTransformColumns != null
+          && customTransformColumns.contains(sourceColDef.getName())) {
+        String cassandraType = sourceColDef.getType().getName().toLowerCase();
+        String columnName = spannerColDef.getName();
+        columnValue =
+            PreparedStatementValueObject.create(
+                cassandraType, customTransformationResponse.get(columnName));
+      } else if (keyValuesJson.has(spannerColumnName)) {
columnValue =
getMappedColumnValue(
spannerColDef, sourceColDef, keyValuesJson, sourceDbTimezoneOffset);
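A minimal sketch of the value-resolution order this hunk introduces (the column name and Cassandra type below are illustrative). Note that membership is checked against sourceColDef.getName() while the value is fetched under spannerColDef.getName(), which silently assumes the source and Spanner column names match:

// Illustration only: resolution order for a column value after this change.
// 1. customTransformationResponse, if the column was produced by the custom transformation
// 2. keyValuesJson, if the Spanner column is present in the key payload
// 3. newValuesJson via getMappedColumnValue otherwise (per the @param docs above)
Map<String, Object> customResponse = Map.of("full_name", "Ada Lovelace");
PreparedStatementValueObject<?> columnValue =
    PreparedStatementValueObject.create("text", customResponse.get("full_name"));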
@@ -344,6 +359,7 @@ private static Map<String, PreparedStatementValueObject<?>> getColumnValues(
* @param newValuesJson the JSON object containing new values for columns.
* @param keyValuesJson the JSON object containing key values for columns.
* @param sourceDbTimezoneOffset the timezone offset of the source database.
+   * @param customTransformationResponse the user-defined transformation output, keyed by column name.
* @return a map of primary key column names to their corresponding prepared statement value
* objects, or null if a required column is missing.
* <p>This method: 1. Iterates over the primary key definitions in the source table schema. 2.
@@ -357,10 +373,14 @@ private static Map<String, PreparedStatementValueObject<?>> getPkColumnValues(
SourceTable sourceTable,
JSONObject newValuesJson,
JSONObject keyValuesJson,
-      String sourceDbTimezoneOffset) {
+      String sourceDbTimezoneOffset,
+      Map<String, Object> customTransformationResponse) {
Map<String, PreparedStatementValueObject<?>> response = new HashMap<>();
ColumnPK[] sourcePKs = sourceTable.getPrimaryKeys();

+    Set<String> customTransformColumns = null;
+    if (customTransformationResponse != null) {
+      customTransformColumns = customTransformationResponse.keySet();
+    }
for (ColumnPK currentSourcePK : sourcePKs) {
String colId = currentSourcePK.getColId();
SourceColumnDefinition sourceColDef = sourceTable.getColDefs().get(colId);
@@ -373,7 +393,14 @@ private static Map<String, PreparedStatementValueObject<?>> getPkColumnValues(
}
String spannerColumnName = spannerColDef.getName();
PreparedStatementValueObject<?> columnValue;
-      if (keyValuesJson.has(spannerColumnName)) {
+      if (customTransformColumns != null
+          && customTransformColumns.contains(sourceColDef.getName())) {
+        String cassandraType = sourceColDef.getType().getName().toLowerCase();
+        String columnName = spannerColDef.getName();
+        columnValue =
+            PreparedStatementValueObject.create(
+                cassandraType, customTransformationResponse.get(columnName));
+      } else if (keyValuesJson.has(spannerColumnName)) {
columnValue =
getMappedColumnValue(
spannerColDef, sourceColDef, keyValuesJson, sourceDbTimezoneOffset);
@@ -172,8 +172,13 @@ private static ByteBuffer parseBlobType(Object colValue) {
return ByteBuffer.wrap((byte[]) colValue);
} else if (colValue instanceof ByteBuffer) {
return (ByteBuffer) colValue;
-    }
-    return ByteBuffer.wrap(java.util.Base64.getDecoder().decode((String) colValue));
+    } else {
+      String strVal = (String) colValue;
+      if (!strVal.matches("^[01]+$")) {
+        return ByteBuffer.wrap(java.util.Base64.getDecoder().decode(strVal));
+      }
+    }
+    throw new IllegalArgumentException("Invalid colValue: " + colValue);
}
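The rewritten parseBlobType now distinguishes the two string encodings a blob value may arrive in; a sketch of the resulting behavior:

// Illustration only: string handling in the revised parseBlobType.
parseBlobType("SGVsbG8=");  // no match for ^[01]+$ -> Base64-decoded to the bytes of "Hello"
parseBlobType("01001000");  // pure binary digits -> IllegalArgumentException, which the
                            // caller in handleSpannerColumnType catches and routes to
                            // convertBinaryEncodedStringToByteArray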

/**
@@ -322,18 +327,22 @@ private static Object handleSpannerColumnType(
String spannerType, String columnName, JSONObject valuesJson) {
try {
if (spannerType.contains("string")) {
-        return valuesJson.optString(columnName);
+        String value = valuesJson.optString(columnName);
+        return value.isEmpty() ? null : value;
} else if (spannerType.contains("bytes")) {
if (valuesJson.isNull(columnName)) {
return null;
}
String hexEncodedString = valuesJson.optString(columnName);
+        if (hexEncodedString.isEmpty()) {
+          return null;
+        }
return safeHandle(
() -> {
try {
-                return safeHandle(() -> convertBinaryEncodedStringToByteArray(hexEncodedString));
+                return safeHandle(() -> parseBlobType(hexEncodedString));
} catch (IllegalArgumentException e) {
-                return parseBlobType(hexEncodedString);
+                return convertBinaryEncodedStringToByteArray(hexEncodedString);
}
});
} else {
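Two behavioral changes land in handleSpannerColumnType: empty strings now map to null (note this also converts a genuinely empty stored string to null, not just an absent column), and for bytes the attempt order is swapped so parseBlobType runs first with convertBinaryEncodedStringToByteArray as the fallback. In summary, for this private method:

// Illustration only: "string"/"bytes" handling after this change.
// "string": ""          -> null (previously returned "")
// "bytes":  "SGVsbG8="  -> parseBlobType succeeds on the Base64 path
// "bytes":  "01001000"  -> parseBlobType throws; falls back to convertBinaryEncodedStringToByteArray
// "bytes":  ""          -> null (new early return)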