diff --git a/src/main/kotlin/com/github/jershell/kbson/BsonFlexibleDecoder.kt b/src/main/kotlin/com/github/jershell/kbson/BsonFlexibleDecoder.kt index 3bb8f84..fd6123c 100644 --- a/src/main/kotlin/com/github/jershell/kbson/BsonFlexibleDecoder.kt +++ b/src/main/kotlin/com/github/jershell/kbson/BsonFlexibleDecoder.kt @@ -3,7 +3,9 @@ package com.github.jershell.kbson import kotlinx.serialization.DeserializationStrategy import kotlinx.serialization.InternalSerializationApi import kotlinx.serialization.SerializationException -import kotlinx.serialization.descriptors.* +import kotlinx.serialization.descriptors.PolymorphicKind +import kotlinx.serialization.descriptors.SerialDescriptor +import kotlinx.serialization.descriptors.StructureKind import kotlinx.serialization.encoding.AbstractDecoder import kotlinx.serialization.encoding.CompositeDecoder import kotlinx.serialization.encoding.CompositeDecoder.Companion.UNKNOWN_NAME @@ -18,18 +20,6 @@ abstract class FlexibleDecoder( override val serializersModule: SerializersModule, val configuration: Configuration ) : AbstractDecoder() { - - /** - * _id field comes always first in MongoDb. - * Sometimes you may need to check if this id is not already read by the [reader] - * - and then you may need to set the [alreadyReadId] to null in order to specify that is has been taken into account. - */ - open var alreadyReadId: Any? - get() = null - set(_) { - //do nothing - } - override fun beginStructure(descriptor: SerialDescriptor): CompositeDecoder { return when (descriptor.kind) { StructureKind.CLASS -> { @@ -37,7 +27,7 @@ abstract class FlexibleDecoder( if (current == null || current == BsonType.DOCUMENT) { reader.readStartDocument() } - BsonFlexibleDecoder(reader, serializersModule, configuration, alreadyReadId) + BsonFlexibleDecoder(reader, serializersModule, configuration) } StructureKind.MAP -> { reader.readStartDocument() @@ -49,7 +39,7 @@ abstract class FlexibleDecoder( } is PolymorphicKind -> { reader.readStartDocument() - PolymorphismDecoder(reader, serializersModule, configuration, alreadyReadId) + PolymorphismDecoder(reader, serializersModule, configuration) } else -> this } @@ -136,7 +126,6 @@ class BsonFlexibleDecoder( reader: AbstractBsonReader, context: SerializersModule, configuration: Configuration, - override var alreadyReadId: Any? = null ) : FlexibleDecoder(reader, context, configuration) { //to handle not optional nullable properties @@ -191,14 +180,6 @@ class BsonFlexibleDecoder( override fun decodeElementIndex(descriptor: SerialDescriptor): Int { initNotOptionalProperties(descriptor) - if (alreadyReadId != null) { - val idIndex = descriptor.getElementIndex("_id") - if (idIndex != UNKNOWN_NAME) { - return idIndex - } else { - alreadyReadId = null - } - } if (reader.state == State.TYPE) { reader.readBsonType() @@ -233,43 +214,54 @@ class BsonFlexibleDecoder( } return null } - - override fun decodeString(): String = - if (alreadyReadId != null) { - val result = alreadyReadId - alreadyReadId = null - result as? String ?: (result as ObjectId).toString() - } else { - super.decodeString() - } } private class PolymorphismDecoder( reader: AbstractBsonReader, val context: SerializersModule, configuration: Configuration, - override var alreadyReadId: Any? = null ) : FlexibleDecoder(reader, context, configuration) { private var decodeCount = 0 + private var discriminatorValue: String? = null @InternalSerializationApi override fun decodeSerializableValue(deserializer: DeserializationStrategy): T = - deserializer.deserialize(BsonFlexibleDecoder(reader, context, configuration, alreadyReadId)) + deserializer.deserialize(BsonFlexibleDecoder(reader, context, configuration)) override fun decodeElementIndex(descriptor: SerialDescriptor): Int { + val classDiscriminator = descriptor.classDiscriminator() return when (decodeCount) { 0 -> { if (reader.state == State.TYPE) { reader.readBsonType() } - val fieldName = reader.readName() - if (fieldName == "_id") { - alreadyReadId = when (reader.currentBsonType) { - BsonType.OBJECT_ID -> reader.readObjectId() - BsonType.STRING -> reader.readString() - else -> error("only ObjectId or string are supported as _id for polymorphism decoder ") + + val mark = reader.mark + + while (discriminatorValue == null) { + when(reader.state) { + State.TYPE -> { + reader.readBsonType() + } + State.NAME -> { + val fieldName = reader.readName() + if (fieldName == classDiscriminator) { + discriminatorValue = reader.readString() + break + } else { + reader.skipValue() + } + } + else -> { + return CompositeDecoder.DECODE_DONE + } } } + if (discriminatorValue == null) { + error("Class discriminator field '$classDiscriminator' not found") + } + + mark.reset() decodeCount++ } 1 -> { @@ -278,6 +270,25 @@ private class PolymorphismDecoder( else -> CompositeDecoder.DECODE_DONE } } + + override fun decodeString(): String { + val currentDiscriminatorValue = discriminatorValue + return if (currentDiscriminatorValue != null) { + discriminatorValue = null + currentDiscriminatorValue + } else { + super.decodeString() + } + } + + private fun SerialDescriptor.classDiscriminator(): String { + for (annotation in this.annotations) { + if (annotation is BsonClassDiscriminator) { + return annotation.discriminator + } + } + return configuration.classDiscriminator + } } private class MapDecoder( @@ -391,4 +402,4 @@ private class ListDecoder( val nextType = reader.readBsonType() return if (nextType == BsonType.END_OF_DOCUMENT) CompositeDecoder.DECODE_DONE else index++ } -} \ No newline at end of file +} diff --git a/src/test/kotlin/com/github/jershell/kbson/KBsonTest.kt b/src/test/kotlin/com/github/jershell/kbson/KBsonTest.kt index fbbc971..131e31e 100644 --- a/src/test/kotlin/com/github/jershell/kbson/KBsonTest.kt +++ b/src/test/kotlin/com/github/jershell/kbson/KBsonTest.kt @@ -1306,6 +1306,18 @@ class KBsonTest { }) } + // class discriminator at a non-standard location + val doc6 = BsonDocument().apply { + append("payload", BsonDocument().apply { + append("_id", BsonObjectId(ObjectId("5d1777814e8c7b408a6ada73"))) + append("someData", BsonString("something")) + append( + conf.classDiscriminator, + BsonString("com.github.jershell.kbson.models.polymorph.SMessage.DataWithObjectId") + ) + }) + } + val polyBson = KBson(serializersModule = DefaultModule + pModule) val res1 = polyBson.parse(SealedWrapper.serializer(), doc1) @@ -1318,6 +1330,8 @@ class KBsonTest { val res5 = polyBson.parse(SealedWrapper.serializer(), doc5) + val res6 = polyBson.parse(SealedWrapper.serializer(), doc6) + assertTrue(res1.payload is SMessage.Error) assertTrue(res2.payload is SMessage.Loading) assertEquals(SealedWrapper(SMessage.Data(someData = "something")), res3) @@ -1330,6 +1344,14 @@ class KBsonTest { ) ), res5 ) + assertEquals( + SealedWrapper( + SMessage.DataWithObjectId( + someData = "something", + _id = ObjectId("5d1777814e8c7b408a6ada73") + ) + ), res6 + ) } @Test