Bulk Load CDK: Remove interfaces from InputFlow, migrate tests to mockk (#49974)
johnny-schmidt authored Dec 20, 2024
1 parent c6e6f78 commit 1bd65af
Showing 31 changed files with 410 additions and 503 deletions.
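
The title describes two related cleanups: the SizedInputFlow and Deserializer interfaces are deleted in favor of the concrete ReservingDeserializingInputFlow and ProtocolMessageDeserializer classes introduced below, and the module's tests switch to mockk, which can stub those concrete classes directly. A hypothetical sketch of what such a test can look like after this change (class, test, and variable names are illustrative, not taken from the commit):

import io.airbyte.cdk.load.message.DestinationMessage
import io.airbyte.cdk.load.message.ProtocolMessageDeserializer
import io.airbyte.cdk.load.task.internal.ReservingDeserializingInputFlow
import io.mockk.Runs
import io.mockk.coEvery
import io.mockk.every
import io.mockk.mockk
import org.junit.jupiter.api.Assertions.assertSame
import org.junit.jupiter.api.Test

class ConcreteCollaboratorMockingSketch {
    @Test
    fun `stubs the concrete deserializer and input flow`() {
        // Both collaborators are plain classes after this change, so mockk can
        // stub them without implementing the old Deserializer/SizedInputFlow interfaces.
        val expected = mockk<DestinationMessage>()
        val deserializer =
            mockk<ProtocolMessageDeserializer> { every { deserialize(any()) } returns expected }
        val inputFlow =
            mockk<ReservingDeserializingInputFlow> { coEvery { collect(any()) } just Runs }

        // In a real test these mocks would be handed to the task under test;
        // here we only check the stubbed behavior.
        assertSame(expected, deserializer.deserialize("{}"))
    }
}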
9 changes: 9 additions & 0 deletions airbyte-cdk/bulk/core/load/build.gradle
@@ -52,6 +52,15 @@ tasks.named('check').configure {
dependsOn integrationTest
}

project.tasks.matching {
it.name == 'spotbugsIntegrationTestLegacy' ||
it.name == 'spotbugsIntegrationTest' ||
it.name == 'spotbugsTest' ||
it.name == 'spotbugsMain'
}.configureEach {
enabled = false
}

test {
systemProperties(["mockk.junit.extension.requireParallelTesting":"true"])
}
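
The Gradle block added above disables the SpotBugs tasks for this module and sets a mockk system property; judging by its name, mockk.junit.extension.requireParallelTesting makes mockk's JUnit 5 extension insist that the suite runs with JUnit parallel execution enabled (that reading is an assumption, the diff only shows the property being set). A minimal sketch of a test class wired through that extension, with names chosen for illustration:

import io.airbyte.cdk.load.message.DestinationMessageFactory
import io.mockk.impl.annotations.MockK
import io.mockk.junit5.MockKExtension
import org.junit.jupiter.api.Assertions.assertNotNull
import org.junit.jupiter.api.Test
import org.junit.jupiter.api.extension.ExtendWith

@ExtendWith(MockKExtension::class)
class ParallelFriendlyTestSketch {
    // MockKExtension initializes annotated mocks before each test.
    @MockK lateinit var factory: DestinationMessageFactory

    @Test
    fun `extension injects mocks`() {
        // A real test would stub factory.fromAirbyteMessage(...) and exercise
        // the code under test; this only shows the wiring.
        assertNotNull(factory)
    }
}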

@@ -11,7 +11,7 @@ import io.airbyte.cdk.load.command.DestinationStream
import io.airbyte.cdk.load.data.ObjectValue
import io.airbyte.cdk.load.message.Batch
import io.airbyte.cdk.load.message.DestinationFile
import io.airbyte.cdk.load.message.DestinationRecord
import io.airbyte.cdk.load.message.DestinationRecordAirbyteValue
import io.airbyte.cdk.load.message.SimpleBatch
import io.airbyte.cdk.load.state.StreamProcessingFailed
import io.airbyte.cdk.load.test.util.OutputRecord
@@ -38,7 +38,7 @@ class MockStreamLoader(override val stream: DestinationStream) : StreamLoader {
override val groupId: String? = null
}

data class LocalBatch(val records: List<DestinationRecord>) : MockBatch() {
data class LocalBatch(val records: List<DestinationRecordAirbyteValue>) : MockBatch() {
override val state = Batch.State.STAGED
}
data class LocalFileBatch(val file: DestinationFile) : MockBatch() {
@@ -72,7 +72,7 @@ class MockStreamLoader(override val stream: DestinationStream) : StreamLoader {
}

override suspend fun processRecords(
records: Iterator<DestinationRecord>,
records: Iterator<DestinationRecordAirbyteValue>,
totalSizeBytes: Long,
endOfStream: Boolean
): Batch {

@@ -5,7 +5,7 @@
package io.airbyte.cdk.load.data

import io.airbyte.cdk.load.command.DestinationStream
import io.airbyte.cdk.load.message.DestinationRecord
import io.airbyte.cdk.load.message.DestinationRecordAirbyteValue
import io.airbyte.cdk.load.message.Meta
import java.util.*

@@ -59,5 +59,7 @@ fun Pair<AirbyteValue, List<Meta.Change>>.withAirbyteMeta(
DestinationRecordToAirbyteValueWithMeta(stream, flatten)
.convert(first, emittedAtMs, Meta(second))

fun DestinationRecord.dataWithAirbyteMeta(stream: DestinationStream, flatten: Boolean = false) =
DestinationRecordToAirbyteValueWithMeta(stream, flatten).convert(data, emittedAtMs, meta)
fun DestinationRecordAirbyteValue.dataWithAirbyteMeta(
stream: DestinationStream,
flatten: Boolean = false
) = DestinationRecordToAirbyteValueWithMeta(stream, flatten).convert(data, emittedAtMs, meta)

@@ -79,13 +79,15 @@ data class Meta(
}
}

data class DestinationRecord(
sealed interface DestinationRecord : DestinationRecordDomainMessage

data class DestinationRecordAirbyteValue(
override val stream: DestinationStream.Descriptor,
val data: AirbyteValue,
val emittedAtMs: Long,
val meta: Meta?,
val serialized: String,
) : DestinationStreamAffinedMessage {
) : DestinationRecord {
override fun asProtocolMessage(): AirbyteMessage =
AirbyteMessage()
.withType(AirbyteMessage.Type.RECORD)
@@ -348,7 +350,10 @@ class DestinationMessageFactory(
private val catalog: DestinationCatalog,
@Value("\${airbyte.file-transfer.enabled}") private val fileTransferEnabled: Boolean,
) {
fun fromAirbyteMessage(message: AirbyteMessage, serialized: String): DestinationMessage {
fun fromAirbyteMessage(
message: AirbyteMessage,
serialized: String,
): DestinationMessage {
fun toLong(value: Any?, name: String): Long? {
return value?.let {
when (it) {
Expand Down Expand Up @@ -391,7 +396,7 @@ class DestinationMessageFactory(
)
)
} else {
DestinationRecord(
DestinationRecordAirbyteValue(
stream = stream.descriptor,
data =
message.record.data?.let {

@@ -8,19 +8,17 @@ import io.airbyte.cdk.load.util.deserializeToClass
import io.airbyte.protocol.models.v0.AirbyteMessage
import jakarta.inject.Singleton

interface Deserializer<T : Any> {
fun deserialize(serialized: String): T
}

/**
* Converts the internal @[AirbyteMessage] to the internal @[DestinationMessage]. Ideally, this would
* not use protocol messages at all, but rather a specialized deserializer for routing.
*/
@Singleton
class DefaultDestinationMessageDeserializer(private val messageFactory: DestinationMessageFactory) :
Deserializer<DestinationMessage> {

override fun deserialize(serialized: String): DestinationMessage {
class ProtocolMessageDeserializer(
private val destinationMessageFactory: DestinationMessageFactory
) {
fun deserialize(
serialized: String,
): DestinationMessage {
val airbyteMessage =
try {
serialized.deserializeToClass(AirbyteMessage::class.java)
@@ -49,13 +47,9 @@ class DefaultDestinationMessageDeserializer(private val messageFactory: Destinat
)
}

val internalDestinationMessage =
try {
messageFactory.fromAirbyteMessage(airbyteMessage, serialized)
} catch (t: Throwable) {
throw RuntimeException("Failed to convert AirbyteMessage to DestinationMessage", t)
}

return internalDestinationMessage
return destinationMessageFactory.fromAirbyteMessage(
airbyteMessage,
serialized,
)
}
}
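
As the KDoc above notes, this class turns one serialized protocol line into an internal DestinationMessage. A minimal usage sketch, assuming a DestinationMessageFactory instance is already available (the wiring is illustrative, not part of this commit):

import io.airbyte.cdk.load.message.DestinationMessage
import io.airbyte.cdk.load.message.DestinationMessageFactory
import io.airbyte.cdk.load.message.ProtocolMessageDeserializer

// One serialized Airbyte protocol line in, one internal message out.
fun parseLine(factory: DestinationMessageFactory, line: String): DestinationMessage =
    ProtocolMessageDeserializer(factory).deserialize(line)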

@@ -30,7 +30,7 @@ sealed class DestinationStreamEvent : Sized
data class StreamRecordEvent(
val index: Long,
override val sizeBytes: Long,
val record: DestinationRecord
val payload: DestinationRecordAirbyteValue
) : DestinationStreamEvent()

/**

@@ -11,7 +11,6 @@ import io.airbyte.cdk.load.command.DestinationStream
import io.airbyte.cdk.load.message.BatchEnvelope
import io.airbyte.cdk.load.message.CheckpointMessageWrapped
import io.airbyte.cdk.load.message.DestinationFile
import io.airbyte.cdk.load.message.DestinationMessage
import io.airbyte.cdk.load.message.DestinationStreamEvent
import io.airbyte.cdk.load.message.MessageQueueSupplier
import io.airbyte.cdk.load.message.QueueWriter
@@ -29,7 +28,7 @@ import io.airbyte.cdk.load.task.implementor.TeardownTaskFactory
import io.airbyte.cdk.load.task.internal.FlushCheckpointsTaskFactory
import io.airbyte.cdk.load.task.internal.FlushTickTask
import io.airbyte.cdk.load.task.internal.InputConsumerTaskFactory
import io.airbyte.cdk.load.task.internal.SizedInputFlow
import io.airbyte.cdk.load.task.internal.ReservingDeserializingInputFlow
import io.airbyte.cdk.load.task.internal.SpillToDiskTaskFactory
import io.airbyte.cdk.load.task.internal.TimedForcedCheckpointFlushTask
import io.airbyte.cdk.load.task.internal.UpdateCheckpointsTask
@@ -125,7 +124,7 @@ class DefaultDestinationTaskLauncher(
@Value("\${airbyte.file-transfer.enabled}") private val fileTransferEnabled: Boolean,

// Input Consumer requirements
private val inputFlow: SizedInputFlow<Reserved<DestinationMessage>>,
private val inputFlow: ReservingDeserializingInputFlow,
private val recordQueueSupplier:
MessageQueueSupplier<DestinationStream.Descriptor, Reserved<DestinationStreamEvent>>,
private val checkpointQueue: QueueWriter<Reserved<CheckpointMessageWrapped>>,

@@ -9,14 +9,13 @@ import io.airbyte.cdk.load.command.DestinationConfiguration
import io.airbyte.cdk.load.command.DestinationStream
import io.airbyte.cdk.load.message.Batch
import io.airbyte.cdk.load.message.BatchEnvelope
import io.airbyte.cdk.load.message.Deserializer
import io.airbyte.cdk.load.message.DestinationMessage
import io.airbyte.cdk.load.message.DestinationRecord
import io.airbyte.cdk.load.message.DestinationRecordAirbyteValue
import io.airbyte.cdk.load.message.DestinationRecordStreamComplete
import io.airbyte.cdk.load.message.DestinationRecordStreamIncomplete
import io.airbyte.cdk.load.message.DestinationStreamAffinedMessage
import io.airbyte.cdk.load.message.MessageQueue
import io.airbyte.cdk.load.message.MultiProducerChannel
import io.airbyte.cdk.load.message.ProtocolMessageDeserializer
import io.airbyte.cdk.load.state.ReservationManager
import io.airbyte.cdk.load.state.SyncManager
import io.airbyte.cdk.load.task.DestinationTaskLauncher
@@ -47,7 +46,7 @@ interface ProcessRecordsTask : KillableScope
class DefaultProcessRecordsTask(
private val config: DestinationConfiguration,
private val taskLauncher: DestinationTaskLauncher,
private val deserializer: Deserializer<DestinationMessage>,
private val deserializer: ProtocolMessageDeserializer,
private val syncManager: SyncManager,
private val diskManager: ReservationManager,
private val inputQueue: MessageQueue<FileAggregateMessage>,
@@ -70,7 +69,7 @@ class DefaultProcessRecordsTask(
file.localFile.inputStream().use {
val records =
if (file.isEmpty) {
emptyList<DestinationRecord>().listIterator()
emptyList<DestinationRecordAirbyteValue>().listIterator()
} else {
it.toRecordIterator()
}
@@ -91,7 +90,11 @@ class DefaultProcessRecordsTask(
log.info { "Forcing finalization of all accumulators." }
accumulators.forEach { (streamDescriptor, acc) ->
val finalBatch =
acc.processRecords(emptyList<DestinationRecord>().listIterator(), 0, true)
acc.processRecords(
emptyList<DestinationRecordAirbyteValue>().listIterator(),
0,
true
)
handleBatch(streamDescriptor, finalBatch, null)
}
}
@@ -113,7 +116,7 @@ }
}
}

private fun InputStream.toRecordIterator(): Iterator<DestinationRecord> {
private fun InputStream.toRecordIterator(): Iterator<DestinationRecordAirbyteValue> {
return lineSequence()
.map {
when (val message = deserializer.deserialize(it)) {
@@ -127,7 +130,7 @@ class DefaultProcessRecordsTask(
.takeWhile {
it !is DestinationRecordStreamComplete && it !is DestinationRecordStreamIncomplete
}
.map { it as DestinationRecord }
.map { it as DestinationRecordAirbyteValue }
.iterator()
}
}
@@ -147,7 +150,7 @@ data class FileAggregateMessage(
@Secondary
class DefaultProcessRecordsTaskFactory(
private val config: DestinationConfiguration,
private val deserializer: Deserializer<DestinationMessage>,
private val deserializer: ProtocolMessageDeserializer,
private val syncManager: SyncManager,
@Named("diskManager") private val diskManager: ReservationManager,
@Named("fileAggregateQueue") private val inputQueue: MessageQueue<FileAggregateMessage>,

@@ -14,8 +14,7 @@ import io.airbyte.cdk.load.message.CheckpointMessageWrapped
import io.airbyte.cdk.load.message.DestinationFile
import io.airbyte.cdk.load.message.DestinationFileStreamComplete
import io.airbyte.cdk.load.message.DestinationFileStreamIncomplete
import io.airbyte.cdk.load.message.DestinationMessage
import io.airbyte.cdk.load.message.DestinationRecord
import io.airbyte.cdk.load.message.DestinationRecordAirbyteValue
import io.airbyte.cdk.load.message.DestinationRecordStreamComplete
import io.airbyte.cdk.load.message.DestinationRecordStreamIncomplete
import io.airbyte.cdk.load.message.DestinationStreamAffinedMessage
@@ -55,7 +54,7 @@ interface InputConsumerTask : KillableScope
@Secondary
class DefaultInputConsumerTask(
private val catalog: DestinationCatalog,
private val inputFlow: SizedInputFlow<Reserved<DestinationMessage>>,
private val inputFlow: ReservingDeserializingInputFlow,
private val recordQueueSupplier:
MessageQueueSupplier<DestinationStream.Descriptor, Reserved<DestinationStreamEvent>>,
private val checkpointQueue: QueueWriter<Reserved<CheckpointMessageWrapped>>,
@@ -72,12 +71,12 @@ class DefaultInputConsumerTask(
val manager = syncManager.getStreamManager(stream)
val recordQueue = recordQueueSupplier.get(stream)
when (val message = reserved.value) {
is DestinationRecord -> {
is DestinationRecordAirbyteValue -> {
val wrapped =
StreamRecordEvent(
index = manager.countRecordIn(),
sizeBytes = sizeBytes,
record = message
payload = message
)
recordQueue.publish(reserved.replace(wrapped))
}
@@ -193,7 +192,7 @@
interface InputConsumerTaskFactory {
fun make(
catalog: DestinationCatalog,
inputFlow: SizedInputFlow<Reserved<DestinationMessage>>,
inputFlow: ReservingDeserializingInputFlow,
recordQueueSupplier:
MessageQueueSupplier<DestinationStream.Descriptor, Reserved<DestinationStreamEvent>>,
checkpointQueue: QueueWriter<Reserved<CheckpointMessageWrapped>>,
@@ -207,7 +206,7 @@ class DefaultInputConsumerTaskFactory(private val syncManager: SyncManager) :
InputConsumerTaskFactory {
override fun make(
catalog: DestinationCatalog,
inputFlow: SizedInputFlow<Reserved<DestinationMessage>>,
inputFlow: ReservingDeserializingInputFlow,
recordQueueSupplier:
MessageQueueSupplier<DestinationStream.Descriptor, Reserved<DestinationStreamEvent>>,
checkpointQueue: QueueWriter<Reserved<CheckpointMessageWrapped>>,

@@ -5,8 +5,8 @@
package io.airbyte.cdk.load.task.internal

import io.airbyte.cdk.load.command.DestinationConfiguration
import io.airbyte.cdk.load.message.Deserializer
import io.airbyte.cdk.load.message.DestinationMessage
import io.airbyte.cdk.load.message.ProtocolMessageDeserializer
import io.airbyte.cdk.load.state.ReservationManager
import io.airbyte.cdk.load.state.Reserved
import io.github.oshai.kotlinlogging.KotlinLogging
@@ -16,17 +16,18 @@ import java.io.InputStream
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.FlowCollector

interface SizedInputFlow<T> : Flow<Pair<Long, T>>

abstract class ReservingDeserializingInputFlow<T : Any> : SizedInputFlow<Reserved<T>> {
@Singleton
class ReservingDeserializingInputFlow(
val config: DestinationConfiguration,
val deserializer: ProtocolMessageDeserializer,
@Named("memoryManager") val memoryManager: ReservationManager,
val inputStream: InputStream,
) : Flow<Pair<Long, Reserved<DestinationMessage>>> {
val log = KotlinLogging.logger {}

abstract val config: DestinationConfiguration
abstract val deserializer: Deserializer<T>
abstract val memoryManager: ReservationManager
abstract val inputStream: InputStream

override suspend fun collect(collector: FlowCollector<Pair<Long, Reserved<T>>>) {
override suspend fun collect(
collector: FlowCollector<Pair<Long, Reserved<DestinationMessage>>>
) {
log.info {
"Reserved ${memoryManager.totalCapacityBytes/1024}mb memory for input processing"
}
@@ -50,11 +51,3 @@ abstract class ReservingDeserializingInputFlow<T : Any> : SizedInputFlow<Reserve
log.info { "Finished processing input" }
}
}

@Singleton
class DefaultInputFlow(
override val config: DestinationConfiguration,
override val deserializer: Deserializer<DestinationMessage>,
@Named("memoryManager") override val memoryManager: ReservationManager,
override val inputStream: InputStream
) : ReservingDeserializingInputFlow<DestinationMessage>()

@@ -92,7 +92,7 @@ class DefaultSpillToDiskTask(
diskManager.reserve(event.sizeBytes)

// write to disk
outputStream.write(event.record.serialized)
outputStream.write(event.payload.serialized)
outputStream.write("\n")

// calculate whether we should flush

@@ -7,7 +7,7 @@ package io.airbyte.cdk.load.write
import io.airbyte.cdk.load.command.DestinationStream
import io.airbyte.cdk.load.message.Batch
import io.airbyte.cdk.load.message.DestinationFile
import io.airbyte.cdk.load.message.DestinationRecord
import io.airbyte.cdk.load.message.DestinationRecordAirbyteValue
import io.airbyte.cdk.load.message.SimpleBatch
import io.airbyte.cdk.load.state.StreamProcessingFailed

@@ -56,7 +56,7 @@

interface BatchAccumulator {
suspend fun processRecords(
records: Iterator<DestinationRecord>,
records: Iterator<DestinationRecordAirbyteValue>,
totalSizeBytes: Long,
endOfStream: Boolean = false
): Batch =

0 comments on commit 1bd65af
