new interfaces
edgao committed Jun 7, 2024
1 parent 026cf75 commit c972810
Showing 4 changed files with 268 additions and 55 deletions.
@@ -24,24 +24,33 @@ import io.airbyte.cdk.integrations.base.TypingAndDedupingFlag.getRawNamespaceOverride
import io.airbyte.cdk.integrations.base.errors.messages.ErrorMessage.getErrorMessage
import io.airbyte.cdk.integrations.base.ssh.SshWrappedDestination
import io.airbyte.cdk.integrations.destination.NamingConventionTransformer
import io.airbyte.cdk.integrations.destination.async.AsyncStreamConsumer
import io.airbyte.cdk.integrations.destination.async.buffers.BufferManager
import io.airbyte.cdk.integrations.destination.async.deser.AirbyteMessageDeserializer
import io.airbyte.cdk.integrations.destination.async.deser.StreamAwareDataTransformer
import io.airbyte.cdk.integrations.destination.async.state.FlushFailure
import io.airbyte.cdk.integrations.destination.jdbc.AbstractJdbcDestination
import io.airbyte.cdk.integrations.destination.jdbc.typing_deduping.JdbcDestinationHandler
import io.airbyte.cdk.integrations.destination.jdbc.typing_deduping.JdbcSqlGenerator
import io.airbyte.cdk.integrations.destination.jdbc.typing_deduping.JdbcV1V2Migrator
import io.airbyte.cdk.integrations.destination.s3.AesCbcEnvelopeEncryption
import io.airbyte.cdk.integrations.destination.s3.AesCbcEnvelopeEncryptionBlobDecorator
import io.airbyte.cdk.integrations.destination.s3.EncryptionConfig
import io.airbyte.cdk.integrations.destination.s3.EncryptionConfig.Companion.fromJson
import io.airbyte.cdk.integrations.destination.s3.FileUploadFormat
import io.airbyte.cdk.integrations.destination.s3.NoEncryption
import io.airbyte.cdk.integrations.destination.s3.S3BaseChecks.attemptS3WriteAndDelete
import io.airbyte.cdk.integrations.destination.s3.S3DestinationConfig
import io.airbyte.cdk.integrations.destination.s3.S3StorageOperations
import io.airbyte.cdk.integrations.destination.staging.StagingConsumerFactory.Companion.builder
import io.airbyte.cdk.integrations.destination.staging.operation.StagingStreamOperations
import io.airbyte.commons.exceptions.ConnectionErrorException
import io.airbyte.commons.json.Jsons.deserialize
import io.airbyte.commons.json.Jsons.emptyObject
import io.airbyte.commons.json.Jsons.jsonNode
import io.airbyte.commons.resources.MoreResources.readResource
import io.airbyte.integrations.base.destination.operation.DefaultFlush
import io.airbyte.integrations.base.destination.operation.DefaultSyncOperation
import io.airbyte.integrations.base.destination.typing_deduping.CatalogParser
import io.airbyte.integrations.base.destination.typing_deduping.DefaultTyperDeduper
import io.airbyte.integrations.base.destination.typing_deduping.DestinationHandler
@@ -52,8 +61,10 @@ import io.airbyte.integrations.base.destination.typing_deduping.SqlGenerator
import io.airbyte.integrations.base.destination.typing_deduping.TyperDeduper
import io.airbyte.integrations.base.destination.typing_deduping.migrators.Migration
import io.airbyte.integrations.destination.redshift.constants.RedshiftDestinationConstants
import io.airbyte.integrations.destination.redshift.operation.RedshiftStagingStorageOperation
import io.airbyte.integrations.destination.redshift.operations.RedshiftS3StagingSqlOperations
import io.airbyte.integrations.destination.redshift.operations.RedshiftSqlOperations
import io.airbyte.integrations.destination.redshift.typing_deduping.RedshiftDV2Migration
import io.airbyte.integrations.destination.redshift.typing_deduping.RedshiftDestinationHandler
import io.airbyte.integrations.destination.redshift.typing_deduping.RedshiftRawTableAirbyteMetaMigration
import io.airbyte.integrations.destination.redshift.typing_deduping.RedshiftSqlGenerator
@@ -66,6 +77,7 @@ import io.airbyte.protocol.models.v0.ConfiguredAirbyteCatalog
import io.airbyte.protocol.models.v0.ConnectorSpecification
import java.time.Duration
import java.util.Optional
import java.util.concurrent.Executors
import java.util.function.Consumer
import javax.sql.DataSource
import org.apache.commons.lang3.NotImplementedException
@@ -236,8 +248,14 @@ class RedshiftStagingS3Destination :
sqlGenerator: SqlGenerator,
destinationHandler: DestinationHandler<RedshiftState>
): List<Migration<RedshiftState>> {
return listOf<Migration<RedshiftState>>(
RedshiftRawTableAirbyteMetaMigration(database, databaseName)
return listOf(
RedshiftDV2Migration(
namingResolver,
database,
databaseName,
sqlGenerator as RedshiftSqlGenerator,
),
RedshiftRawTableAirbyteMetaMigration(database, databaseName),
)
}

Expand Down Expand Up @@ -285,7 +303,6 @@ class RedshiftStagingS3Destination :

val sqlGenerator = RedshiftSqlGenerator(namingResolver, config)
val parsedCatalog: ParsedCatalog
val typerDeduper: TyperDeduper
val database = getDatabase(getDataSource(config))
val databaseName = config[JdbcUtils.DATABASE_KEY].asText()
val catalogParser: CatalogParser
@@ -300,54 +317,61 @@
val redshiftDestinationHandler =
RedshiftDestinationHandler(databaseName, database, rawNamespace)
parsedCatalog = catalogParser.parseCatalog(catalog)
val migrator = JdbcV1V2Migrator(namingResolver, database, databaseName)
val v2TableMigrator = NoopV2TableMigrator()
val disableTypeDedupe =
config.has(DISABLE_TYPE_DEDUPE) && config[DISABLE_TYPE_DEDUPE].asBoolean(false)
val redshiftMigrations: List<Migration<RedshiftState>> =
getMigrations(database, databaseName, sqlGenerator, redshiftDestinationHandler)
typerDeduper =
if (disableTypeDedupe) {
NoOpTyperDeduperWithV1V2Migrations(
sqlGenerator,
redshiftDestinationHandler,
parsedCatalog,
migrator,
v2TableMigrator,
redshiftMigrations
)
} else {
DefaultTyperDeduper(
sqlGenerator,
redshiftDestinationHandler,
parsedCatalog,
migrator,
v2TableMigrator,
redshiftMigrations
)
}

return builder(
outputRecordCollector,
database,
RedshiftS3StagingSqlOperations(
namingResolver,
s3Config.getS3Client(),
s3Config,
encryptionConfig
),
namingResolver,
config,
catalog,
isPurgeStagingData(s3Options),
typerDeduper,
parsedCatalog,
defaultNamespace,
JavaBaseConstants.DestinationColumns.V2_WITH_META
)
.setDataTransformer(getDataTransformer(parsedCatalog, defaultNamespace))
.build()
.createAsync()
val s3StorageOperations =
S3StorageOperations(namingResolver, s3Config.getS3Client(), s3Config)
val keyEncryptingKey: ByteArray?
if (encryptionConfig is AesCbcEnvelopeEncryption) {
s3StorageOperations.addBlobDecorator(
AesCbcEnvelopeEncryptionBlobDecorator(encryptionConfig.key)
)
keyEncryptingKey = encryptionConfig.key
} else {
keyEncryptingKey = null
}

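// The storage operation below pairs the S3 staging side (s3StorageOperations plus the optional KEK)
// with the Redshift SQL side (sqlGenerator and redshiftDestinationHandler) for raw/final table handling.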
val redshiftStagingStorageOperation = RedshiftStagingStorageOperation(
// S3DestinationConfig.getS3DestinationConfig always sets a nonnull bucket path
// TODO mark bucketPath as non-nullable
s3Config.bucketPath!!,
isPurgeStagingData(s3Options),
s3StorageOperations,
sqlGenerator,
redshiftDestinationHandler,
keyEncryptingKey,
)
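// DefaultSyncOperation drives the per-stream work; the lambda builds a CSV-based StagingStreamOperations
// for each stream's initial status, with typing/deduping toggled by disableTypeDedupe above.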
val syncOperation = DefaultSyncOperation(
parsedCatalog,
redshiftDestinationHandler,
defaultNamespace,
{ initialStatus, disableTD ->
StagingStreamOperations(
redshiftStagingStorageOperation,
initialStatus,
FileUploadFormat.CSV,
JavaBaseConstants.DestinationColumns.V2_WITH_META,
disableTD
)
},
redshiftMigrations,
disableTypeDedupe,
)
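// The async consumer buffers incoming records via BufferManager (capped at bufferMemoryLimit) and hands
// roughly OPTIMAL_FLUSH_BATCH_SIZE-sized batches to syncOperation through DefaultFlush; on close it
// finalizes all streams.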
return AsyncStreamConsumer(
outputRecordCollector,
onStart = {},
onClose = {_, streamSyncSummaries -> syncOperation.finalizeStreams(streamSyncSummaries)},
onFlush = DefaultFlush(OPTIMAL_FLUSH_BATCH_SIZE, syncOperation),
catalog,
BufferManager(bufferMemoryLimit),
Optional.ofNullable(defaultNamespace),
FlushFailure(),
Executors.newFixedThreadPool(5),
AirbyteMessageDeserializer(getDataTransformer(parsedCatalog, defaultNamespace)),
)
}

private fun isPurgeStagingData(config: JsonNode?): Boolean {
@@ -367,6 +391,9 @@
"com.amazon.redshift.ssl.NonValidatingFactory"
)

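// Assumed unit is bytes: 50 * 1024 * 1024 ≈ 50 MiB per batch handed to DefaultFlush.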
private const val OPTIMAL_FLUSH_BATCH_SIZE: Long = 50 * 1024 * 1024
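// Presumably caps the async buffer at half of the JVM max heap.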
private val bufferMemoryLimit: Long = (Runtime.getRuntime().maxMemory() * 0.5).toLong()

private fun sshWrappedDestination(): Destination {
return SshWrappedDestination(
RedshiftStagingS3Destination(),
@@ -0,0 +1,150 @@
package io.airbyte.integrations.destination.redshift.operation

import io.airbyte.cdk.integrations.base.JavaBaseConstants
import io.airbyte.cdk.integrations.destination.record_buffer.SerializableBuffer
import io.airbyte.cdk.integrations.destination.s3.AesCbcEnvelopeEncryptionBlobDecorator
import io.airbyte.cdk.integrations.destination.s3.S3StorageOperations
import io.airbyte.cdk.integrations.destination.staging.SerialStagingConsumerFactory
import io.airbyte.integrations.base.destination.operation.StorageOperation
import io.airbyte.integrations.base.destination.typing_deduping.Sql
import io.airbyte.integrations.base.destination.typing_deduping.StreamConfig
import io.airbyte.integrations.base.destination.typing_deduping.StreamId
import io.airbyte.integrations.base.destination.typing_deduping.TyperDeduperUtil
import io.airbyte.integrations.destination.redshift.RedshiftSQLNameTransformer
import io.airbyte.integrations.destination.redshift.typing_deduping.RedshiftDestinationHandler
import io.airbyte.integrations.destination.redshift.typing_deduping.RedshiftSqlGenerator
import io.airbyte.protocol.models.v0.DestinationSyncMode
import io.github.oshai.kotlinlogging.KotlinLogging
import java.time.Instant
import java.time.ZoneOffset
import java.time.ZonedDateTime
import java.util.Optional
import java.util.UUID

private val log = KotlinLogging.logger {}

/**
* @param keyEncryptingKey The KEK to use when writing to S3, or null if encryption is disabled.
* If this parameter is nonnull, then `s3StorageOperations` MUST have an
* [AesCbcEnvelopeEncryptionBlobDecorator] added (via `s3StorageOperations#addBlobDecorator`).
*/
class RedshiftStagingStorageOperation(
private val bucketPath: String,
private val keepStagingFiles: Boolean,
private val s3StorageOperations: S3StorageOperations,
private val sqlGenerator: RedshiftSqlGenerator,
private val destinationHandler: RedshiftDestinationHandler,
private val keyEncryptingKey: ByteArray?,
private val connectionId: UUID = SerialStagingConsumerFactory.RANDOM_CONNECTION_ID,
private val writeDatetime: ZonedDateTime = Instant.now().atZone(ZoneOffset.UTC),
): StorageOperation<SerializableBuffer> {
override fun prepareStage(streamId: StreamId, destinationSyncMode: DestinationSyncMode) {
// create raw table
destinationHandler.execute(Sql.of(createRawTableQuery(streamId)))
if (destinationSyncMode == DestinationSyncMode.OVERWRITE) {
destinationHandler.execute(Sql.of(truncateRawTableQuery(streamId)))
}
// create bucket for staging files
s3StorageOperations.createBucketIfNotExists()
}

override fun writeToStage(streamId: StreamId, data: SerializableBuffer) {
val objectPath: String = getStagingPath(streamId)
log.info {
"Uploading records for ${streamId.rawNamespace}.${streamId.rawName} to path $objectPath"
}
s3StorageOperations.uploadRecordsToBucket(data, streamId.rawNamespace, objectPath)
}

override fun cleanupStage(streamId: StreamId) {
if (keepStagingFiles) return
val stagingRootPath = getStagingPath(streamId)
log.info { "Cleaning up staging path at $stagingRootPath" }
s3StorageOperations.dropBucketObject(stagingRootPath)
}

override fun createFinalTable(streamConfig: StreamConfig, suffix: String, replace: Boolean) {
destinationHandler.execute(sqlGenerator.createTable(streamConfig, suffix, replace))
}

override fun softResetFinalTable(streamConfig: StreamConfig) {
TyperDeduperUtil.executeSoftReset(
sqlGenerator = sqlGenerator,
destinationHandler = destinationHandler,
streamConfig,
)
}

override fun overwriteFinalTable(streamConfig: StreamConfig, tmpTableSuffix: String) {
if (tmpTableSuffix.isNotBlank()) {
log.info {
"Overwriting table ${streamConfig.id.finalTableId(RedshiftSqlGenerator.QUOTE)} with ${
streamConfig.id.finalTableId(
RedshiftSqlGenerator.QUOTE,
tmpTableSuffix,
)
}"
}
destinationHandler.execute(
sqlGenerator.overwriteFinalTable(streamConfig.id, tmpTableSuffix)
)
}
}

override fun typeAndDedupe(
streamConfig: StreamConfig,
maxProcessedTimestamp: Optional<Instant>,
finalTableSuffix: String
) {
TyperDeduperUtil.executeTypeAndDedupe(
sqlGenerator = sqlGenerator,
destinationHandler = destinationHandler,
streamConfig,
maxProcessedTimestamp,
finalTableSuffix,
)
}

private fun getStagingPath(streamId: StreamId): String {
val prefix =
if (bucketPath.isEmpty()) ""
else bucketPath + (if (bucketPath.endsWith("/")) "" else "/")
return nameTransformer.applyDefaultCase(
String.format(
"%s%s/%s_%02d_%02d_%02d_%s/",
prefix,
nameTransformer.applyDefaultCase(
// I have no idea why we're doing this.
// streamId.rawName already has been passed through the name transformer.
nameTransformer.convertStreamName(streamId.rawName)
),
writeDatetime.year,
writeDatetime.monthValue,
writeDatetime.dayOfMonth,
writeDatetime.hour,
connectionId
)
)
}

companion object {
private val nameTransformer = RedshiftSQLNameTransformer()

private fun createRawTableQuery(streamId: StreamId): String {
return """
CREATE TABLE IF NOT EXISTS "${streamId.rawNamespace}"."${streamId.rawName}" (
${JavaBaseConstants.COLUMN_NAME_AB_RAW_ID} VARCHAR(36),
${JavaBaseConstants.COLUMN_NAME_AB_EXTRACTED_AT} TIMESTAMPTZ DEFAULT GETDATE(),
${JavaBaseConstants.COLUMN_NAME_AB_LOADED_AT} TIMESTAMPTZ,
${JavaBaseConstants.COLUMN_NAME_DATA} SUPER NOT NULL,
${JavaBaseConstants.COLUMN_NAME_AB_META} SUPER NULL
)
""".trimIndent()
}

private fun truncateRawTableQuery(
streamId: StreamId,
): String {
return String.format("""TRUNCATE TABLE "%s"."%s";""", streamId.rawNamespace, streamId.rawName)
}
}
}
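For orientation, a hedged sketch of how one stream might be driven through this storage operation. In this commit the actual caller is StagingStreamOperations under DefaultSyncOperation (see the first file above); the helper name, the APPEND sync mode, the empty timestamp/suffix arguments, and the buffer parameter below are illustrative assumptions; only the method calls come from the class above.

import io.airbyte.cdk.integrations.destination.record_buffer.SerializableBuffer
import io.airbyte.integrations.base.destination.typing_deduping.StreamConfig
import io.airbyte.integrations.base.destination.typing_deduping.StreamId
import io.airbyte.integrations.destination.redshift.operation.RedshiftStagingStorageOperation
import io.airbyte.protocol.models.v0.DestinationSyncMode
import java.util.Optional

// Hypothetical helper (not part of this commit): walk one stream through the staging lifecycle.
fun stageOneBatchSketch(
    op: RedshiftStagingStorageOperation,
    streamId: StreamId,
    streamConfig: StreamConfig,
    buffer: SerializableBuffer, // assumed: an already-flushed CSV buffer produced upstream
) {
    op.prepareStage(streamId, DestinationSyncMode.APPEND) // create raw table (truncated on OVERWRITE) and staging bucket
    op.writeToStage(streamId, buffer) // upload the serialized batch to the S3 staging path
    op.typeAndDedupe(streamConfig, Optional.empty(), "") // type-and-dedupe raw data into the final table (no suffix)
    op.cleanupStage(streamId) // delete staging objects unless keepStagingFiles is set
}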
@@ -269,25 +269,23 @@ open class RedshiftSqlGenerator(
* Return ROW_NUMBER() OVER (PARTITION BY primaryKeys ORDER BY cursor DESC NULLS LAST,
* _airbyte_extracted_at DESC)
*
* @param primaryKeys
* @param cursor
* @param primaryKey
* @param cursorField
* @return
*/
override fun getRowNumber(primaryKeys: List<ColumnId>, cursor: Optional<ColumnId>): Field<Int> {
override fun getRowNumber(primaryKey: List<ColumnId>, cursorField: Optional<ColumnId>): Field<Int> {
// literally identical to postgres's getRowNumber implementation, changes here probably
// should
// be reflected there
val primaryKeyFields =
if (primaryKeys != null)
primaryKeys
.stream()
.map { columnId: ColumnId -> DSL.field(DSL.quotedName(columnId.name)) }
.collect(Collectors.toList())
else ArrayList()
primaryKey
.stream()
.map { columnId: ColumnId -> DSL.field(DSL.quotedName(columnId.name)) }
.collect(Collectors.toList())
val orderedFields: MutableList<Field<*>> = ArrayList()
// We can still use Jooq's field to get the quoted name with raw sql templating.
// jooq's .desc returns SortField<?> instead of Field<?> and NULLS LAST doesn't work with it
cursor.ifPresent { columnId: ColumnId ->
cursorField.ifPresent { columnId: ColumnId ->
orderedFields.add(
DSL.field("{0} desc NULLS LAST", DSL.field(DSL.quotedName(columnId.name)))
)
@@ -331,6 +329,7 @@
companion object {
const val CASE_STATEMENT_SQL_TEMPLATE: String = "CASE WHEN {0} THEN {1} ELSE {2} END "
const val CASE_STATEMENT_NO_ELSE_SQL_TEMPLATE: String = "CASE WHEN {0} THEN {1} END "
const val QUOTE: String = "\""

private const val AIRBYTE_META_COLUMN_CHANGES_KEY = "changes"

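As a concrete reading of the getRowNumber KDoc above: for a stream whose primary key is the single column id and whose cursor is updated_at, the generated partition/ordering expression should render roughly as follows (illustrative column names; identifier quoting and keyword casing follow the jOOQ template in the hunk above):

ROW_NUMBER() OVER (PARTITION BY "id" ORDER BY "updated_at" desc NULLS LAST, "_airbyte_extracted_at" desc)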