Skip to content

Commit

Permalink
Bulk Load CDK: Remove unnecessary initial marshaling
Browse files Browse the repository at this point in the history
  • Loading branch information
johnny-schmidt committed Dec 20, 2024
1 parent 1bd65af commit 91ebdae
Show file tree
Hide file tree
Showing 11 changed files with 105 additions and 98 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,9 @@ import com.fasterxml.jackson.annotation.JsonProperty
import com.fasterxml.jackson.databind.JsonNode
import io.airbyte.cdk.load.command.DestinationCatalog
import io.airbyte.cdk.load.command.DestinationStream
import io.airbyte.cdk.load.data.AirbyteType
import io.airbyte.cdk.load.data.AirbyteValue
import io.airbyte.cdk.load.data.ObjectValue
import io.airbyte.cdk.load.data.json.JsonToAirbyteValue
import io.airbyte.cdk.load.data.json.toJson
import io.airbyte.cdk.load.data.json.toAirbyteValue
import io.airbyte.cdk.load.message.CheckpointMessage.Checkpoint
import io.airbyte.cdk.load.message.CheckpointMessage.Stats
import io.airbyte.cdk.load.util.deserializeToNode
Expand Down Expand Up @@ -79,31 +78,45 @@ data class Meta(
}
}

sealed interface DestinationRecord : DestinationRecordDomainMessage
/**
 * A record message kept in its raw protocol form.
 *
 * Nothing is marshaled at construction time: the record data is only converted to an
 * [AirbyteValue] when [asRecordMarshaledToAirbyteValue] is called.
 *
 * @property stream descriptor of the stream this record belongs to
 * @property message the protocol message, returned unchanged by [asProtocolMessage]
 * @property serialized the serialized form of the record, passed through as-is by
 * [asRecordSerialized]
 * @property schema the type handed to `toAirbyteValue` when the record data is marshaled
 */
data class DestinationRecord(
    override val stream: DestinationStream.Descriptor,
    val message: AirbyteMessage,
    val serialized: String,
    val schema: AirbyteType
) : DestinationRecordDomainMessage {
    override fun asProtocolMessage(): AirbyteMessage = message

    /** Wraps [stream] and [serialized] without reading the record data. */
    fun asRecordSerialized(): DestinationRecordSerialized =
        DestinationRecordSerialized(stream, serialized)

    /**
     * Marshals the record data against [schema] and carries over the emitted-at timestamp and
     * any meta changes.
     *
     * @return the marshaled record; a record with no data marshals to an empty object
     */
    fun asRecordMarshaledToAirbyteValue(): DestinationRecordAirbyteValue {
        val record = message.record
        // Missing data falls back to an empty object instead of failing on a null receiver.
        // The type is fully qualified so this block does not depend on the file's import list.
        val data =
            record.data?.toAirbyteValue(schema)
                ?: io.airbyte.cdk.load.data.ObjectValue(linkedMapOf())
        // Absent meta (or absent changes) is treated as "no changes".
        val changes =
            record.meta?.changes?.map { Meta.Change(it.field, it.change, it.reason) }
                ?: emptyList()
        return DestinationRecordAirbyteValue(stream, data, record.emittedAt, Meta(changes))
    }
}

/**
 * A record carried purely as its already-serialized text, tagged with the stream it belongs to.
 * Meant for the path from stdin to the spill file, where serializing the record again would be
 * wasted work.
 */
data class DestinationRecordSerialized(
    val stream: DestinationStream.Descriptor,
    val serialized: String,
)

/**
 * Represents a record both deserialized AND marshaled to airbyte value. The marshaling is
 * performed on demand by [DestinationRecord.asRecordMarshaledToAirbyteValue].
 */
data class DestinationRecordAirbyteValue(
override val stream: DestinationStream.Descriptor,
val stream: DestinationStream.Descriptor,
val data: AirbyteValue,
val emittedAtMs: Long,
val meta: Meta?,
val serialized: String,
) : DestinationRecord {
override fun asProtocolMessage(): AirbyteMessage =
AirbyteMessage()
.withType(AirbyteMessage.Type.RECORD)
.withRecord(
AirbyteRecordMessage()
.withStream(stream.name)
.withNamespace(stream.namespace)
.withEmittedAt(emittedAtMs)
.withData(data.toJson())
.also {
if (meta != null) {
it.withMeta(meta.asProtocolObject())
}
}
)
}
)

data class DestinationFile(
override val stream: DestinationStream.Descriptor,
Expand Down Expand Up @@ -396,31 +409,7 @@ class DestinationMessageFactory(
)
)
} else {
DestinationRecordAirbyteValue(
stream = stream.descriptor,
data =
message.record.data?.let {
JsonToAirbyteValue().convert(it, stream.schema)
}
?: ObjectValue(linkedMapOf()),
emittedAtMs = message.record.emittedAt,
meta =
Meta(
changes =
message.record.meta
?.changes
?.map {
Meta.Change(
field = it.field,
change = it.change,
reason = it.reason,
)
}
?.toMutableList()
?: mutableListOf()
),
serialized = serialized
)
DestinationRecord(stream.descriptor, message, serialized, stream.schema)
}
}
AirbyteMessage.Type.TRACE -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import jakarta.inject.Singleton
*/
@Singleton
class ProtocolMessageDeserializer(
private val destinationMessageFactory: DestinationMessageFactory
private val destinationMessageFactory: DestinationMessageFactory,
) {
fun deserialize(
serialized: String,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ sealed class DestinationStreamEvent : Sized
data class StreamRecordEvent(
val index: Long,
override val sizeBytes: Long,
val payload: DestinationRecordAirbyteValue
val payload: DestinationRecordSerialized
) : DestinationStreamEvent()

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import io.airbyte.cdk.load.command.DestinationConfiguration
import io.airbyte.cdk.load.command.DestinationStream
import io.airbyte.cdk.load.message.Batch
import io.airbyte.cdk.load.message.BatchEnvelope
import io.airbyte.cdk.load.message.DestinationRecord
import io.airbyte.cdk.load.message.DestinationRecordAirbyteValue
import io.airbyte.cdk.load.message.DestinationRecordStreamComplete
import io.airbyte.cdk.load.message.DestinationRecordStreamIncomplete
Expand Down Expand Up @@ -130,7 +131,7 @@ class DefaultProcessRecordsTask(
.takeWhile {
it !is DestinationRecordStreamComplete && it !is DestinationRecordStreamIncomplete
}
.map { it as DestinationRecordAirbyteValue }
.map { (it as DestinationRecord).asRecordMarshaledToAirbyteValue() }
.iterator()
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import io.airbyte.cdk.load.message.CheckpointMessageWrapped
import io.airbyte.cdk.load.message.DestinationFile
import io.airbyte.cdk.load.message.DestinationFileStreamComplete
import io.airbyte.cdk.load.message.DestinationFileStreamIncomplete
import io.airbyte.cdk.load.message.DestinationRecordAirbyteValue
import io.airbyte.cdk.load.message.DestinationRecord
import io.airbyte.cdk.load.message.DestinationRecordStreamComplete
import io.airbyte.cdk.load.message.DestinationRecordStreamIncomplete
import io.airbyte.cdk.load.message.DestinationStreamAffinedMessage
Expand Down Expand Up @@ -71,12 +71,12 @@ class DefaultInputConsumerTask(
val manager = syncManager.getStreamManager(stream)
val recordQueue = recordQueueSupplier.get(stream)
when (val message = reserved.value) {
is DestinationRecordAirbyteValue -> {
is DestinationRecord -> {
val wrapped =
StreamRecordEvent(
index = manager.countRecordIn(),
sizeBytes = sizeBytes,
payload = message
payload = message.asRecordSerialized()
)
recordQueue.publish(reserved.replace(wrapped))
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,7 @@ class DestinationRecordAirbyteValueToAirbyteValueWithMetaTest {
)
val expected = LinkedHashMap(expectedMeta)
expected[Meta.COLUMN_NAME_DATA] = data
val mockRecord =
DestinationRecordAirbyteValue(stream.descriptor, data, emittedAtMs, Meta(), "test")
val mockRecord = DestinationRecordAirbyteValue(stream.descriptor, data, emittedAtMs, Meta())
val withMeta = mockRecord.dataWithAirbyteMeta(stream, flatten = false)
val uuid = withMeta.values.remove(Meta.COLUMN_NAME_AB_RAW_ID) as StringValue
Assertions.assertTrue(
Expand All @@ -65,8 +64,7 @@ class DestinationRecordAirbyteValueToAirbyteValueWithMetaTest {
)
val expected = LinkedHashMap(expectedMeta)
data.values.forEach { (name, value) -> expected[name] = value }
val mockRecord =
DestinationRecordAirbyteValue(stream.descriptor, data, emittedAtMs, Meta(), "test")
val mockRecord = DestinationRecordAirbyteValue(stream.descriptor, data, emittedAtMs, Meta())
val withMeta = mockRecord.dataWithAirbyteMeta(stream, flatten = true)
withMeta.values.remove(Meta.COLUMN_NAME_AB_RAW_ID)
Assertions.assertEquals(expected, withMeta.values)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,14 @@

package io.airbyte.cdk.load.task.implementor

import com.fasterxml.jackson.databind.node.JsonNodeFactory
import com.google.common.collect.Range
import io.airbyte.cdk.load.command.DestinationConfiguration
import io.airbyte.cdk.load.command.MockDestinationCatalogFactory
import io.airbyte.cdk.load.data.IntegerValue
import io.airbyte.cdk.load.message.Batch
import io.airbyte.cdk.load.message.BatchEnvelope
import io.airbyte.cdk.load.message.DestinationRecord
import io.airbyte.cdk.load.message.DestinationRecordAirbyteValue
import io.airbyte.cdk.load.message.MessageQueue
import io.airbyte.cdk.load.message.MultiProducerChannel
Expand All @@ -21,6 +23,8 @@ import io.airbyte.cdk.load.task.internal.SpilledRawMessagesLocalFile
import io.airbyte.cdk.load.util.write
import io.airbyte.cdk.load.write.BatchAccumulator
import io.airbyte.cdk.load.write.StreamLoader
import io.airbyte.protocol.models.v0.AirbyteMessage
import io.airbyte.protocol.models.v0.AirbyteRecordMessage
import io.mockk.coEvery
import io.mockk.coVerify
import io.mockk.coVerifySequence
Expand Down Expand Up @@ -61,12 +65,21 @@ class ProcessRecordsTaskTest {
deserializer = mockk(relaxed = true)
coEvery { deserializer.deserialize(any()) } answers
{
DestinationRecordAirbyteValue(
DestinationRecord(
stream = MockDestinationCatalogFactory.stream1.descriptor,
data = IntegerValue(firstArg<String>().toLong()),
emittedAtMs = 0L,
meta = null,
serialized = firstArg<String>()
message =
AirbyteMessage()
.withRecord(
AirbyteRecordMessage()
.withEmittedAt(0L)
.withData(
JsonNodeFactory.instance.numberNode(
firstArg<String>().toLong()
)
)
),
serialized = "ignored",
schema = io.airbyte.cdk.load.data.IntegerType
)
}
processRecordsTaskFactory =
Expand Down Expand Up @@ -147,8 +160,10 @@ class ProcessRecordsTaskTest {
it.batch.groupId == groupId &&
it.batch.state == state &&
it.batch is MockBatch &&
(it.batch as MockBatch).records.map { record -> record.serialized }.toSet() ==
serializedRecords.toSet()
(it.batch as MockBatch)
.records
.map { record -> (record.data as IntegerValue).value.toString() }
.toSet() == serializedRecords.toSet()
}

// Verify the batch was *handled* 3 times but *published* ONLY when it is not complete AND
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,13 @@ package io.airbyte.cdk.load.task.internal

import io.airbyte.cdk.load.command.DestinationConfiguration
import io.airbyte.cdk.load.command.DestinationStream
import io.airbyte.cdk.load.data.NullValue
import io.airbyte.cdk.load.message.DestinationRecordAirbyteValue
import io.airbyte.cdk.load.data.ObjectTypeWithoutSchema
import io.airbyte.cdk.load.message.DestinationRecord
import io.airbyte.cdk.load.message.DestinationRecordSerialized
import io.airbyte.cdk.load.message.ProtocolMessageDeserializer
import io.airbyte.cdk.load.state.ReservationManager
import io.airbyte.cdk.load.state.Reserved
import io.airbyte.protocol.models.v0.AirbyteMessage
import io.mockk.coEvery
import io.mockk.impl.annotations.MockK
import java.io.ByteArrayInputStream
Expand Down Expand Up @@ -55,19 +57,20 @@ class ReservingDeserializingInputFlowTest {
coEvery { config.estimatedRecordMemoryOverheadRatio } returns RATIO
coEvery { deserializer.deserialize(any()) } answers
{
DestinationRecordAirbyteValue(
DestinationRecord(
stream,
NullValue,
0L,
null,
AirbyteMessage(),
firstArg<String>().reversed() + "!",
ObjectTypeWithoutSchema
)
}
val inputs = inputFlow.toList().map { it.first to it.second.value }
val inputs =
inputFlow.toList().map {
it.first to (it.second.value as DestinationRecord).asRecordSerialized()
}
val expectedOutputs =
records.map {
it.length.toLong() to
DestinationRecordAirbyteValue(stream, NullValue, 0L, null, it.reversed() + "!")
it.length.toLong() to DestinationRecordSerialized(stream, it.reversed() + "!")
}
assert(inputs == expectedOutputs)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,9 @@ import com.google.common.collect.Range
import io.airbyte.cdk.load.command.DestinationStream
import io.airbyte.cdk.load.command.MockDestinationCatalogFactory
import io.airbyte.cdk.load.command.MockDestinationConfiguration
import io.airbyte.cdk.load.data.NullValue
import io.airbyte.cdk.load.file.DefaultSpillFileProvider
import io.airbyte.cdk.load.file.SpillFileProvider
import io.airbyte.cdk.load.message.DestinationRecordAirbyteValue
import io.airbyte.cdk.load.message.DestinationRecordSerialized
import io.airbyte.cdk.load.message.DestinationStreamEvent
import io.airbyte.cdk.load.message.DestinationStreamEventQueue
import io.airbyte.cdk.load.message.DestinationStreamQueueSupplier
Expand All @@ -27,7 +26,6 @@ import io.airbyte.cdk.load.state.TimeWindowTrigger
import io.airbyte.cdk.load.task.DestinationTaskLauncher
import io.airbyte.cdk.load.task.MockTaskLauncher
import io.airbyte.cdk.load.task.implementor.FileAggregateMessage
import io.airbyte.cdk.load.test.util.StubDestinationMessageFactory
import io.mockk.coEvery
import io.mockk.coVerify
import io.mockk.every
Expand Down Expand Up @@ -94,9 +92,10 @@ class SpillToDiskTaskTest {
StreamRecordEvent(
3L,
2L,
StubDestinationMessageFactory.makeRecord(
MockDestinationCatalogFactory.stream1,
),
DestinationRecordSerialized(
MockDestinationCatalogFactory.stream1.descriptor,
""
)
)
// flush strategy returns true, so we flush
coEvery { flushStrategy.shouldFlush(any(), any(), any()) } returns true
Expand Down Expand Up @@ -133,9 +132,10 @@ class SpillToDiskTaskTest {
StreamRecordEvent(
3L,
2L,
StubDestinationMessageFactory.makeRecord(
MockDestinationCatalogFactory.stream1,
),
DestinationRecordSerialized(
MockDestinationCatalogFactory.stream1.descriptor,
""
)
)

// must publish 1 record message so range isn't empty
Expand Down Expand Up @@ -253,12 +253,9 @@ class SpillToDiskTaskTest {
index = index,
sizeBytes = Fixtures.SERIALIZED_SIZE_BYTES,
payload =
DestinationRecordAirbyteValue(
stream = MockDestinationCatalogFactory.stream1.descriptor,
data = NullValue,
emittedAtMs = 0,
meta = null,
serialized = "test"
DestinationRecordSerialized(
MockDestinationCatalogFactory.stream1.descriptor,
"",
),
),
),
Expand Down
Loading

0 comments on commit 91ebdae

Please sign in to comment.