Skip to content

Commit dc6c38b

Browse files
authored
Support discriminators not being the first field when decoding (mongodb#1324)
JAVA-5304
1 parent 7c37741 commit dc6c38b

File tree

4 files changed

+167
-30
lines changed

4 files changed

+167
-30
lines changed

bson-kotlinx/src/main/kotlin/org/bson/codecs/kotlinx/BsonDecoder.kt

Lines changed: 77 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ import kotlinx.serialization.modules.SerializersModule
3131
import org.bson.AbstractBsonReader
3232
import org.bson.BsonInvalidOperationException
3333
import org.bson.BsonReader
34+
import org.bson.BsonReaderMark
3435
import org.bson.BsonType
3536
import org.bson.BsonValue
3637
import org.bson.codecs.BsonValueCodec
@@ -68,6 +69,20 @@ internal open class DefaultBsonDecoder(
6869
val validKeyKinds = setOf(PrimitiveKind.STRING, PrimitiveKind.CHAR, SerialKind.ENUM)
6970
val bsonValueCodec = BsonValueCodec()
7071
const val UNKNOWN_INDEX = -10
72+
fun validateCurrentBsonType(
73+
reader: AbstractBsonReader,
74+
expectedType: BsonType,
75+
descriptor: SerialDescriptor,
76+
actualType: (descriptor: SerialDescriptor) -> String = { it.kind.toString() }
77+
) {
78+
reader.currentBsonType?.let {
79+
if (it != expectedType) {
80+
throw SerializationException(
81+
"Invalid data for `${actualType(descriptor)}` expected a bson " +
82+
"${expectedType.name.lowercase()} found: ${reader.currentBsonType}")
83+
}
84+
}
85+
}
7186
}
7287

7388
private fun initElementMetadata(descriptor: SerialDescriptor) {
@@ -119,29 +134,14 @@ internal open class DefaultBsonDecoder(
119134

120135
@Suppress("ReturnCount")
121136
override fun beginStructure(descriptor: SerialDescriptor): CompositeDecoder {
122-
when (descriptor.kind) {
123-
is StructureKind.LIST -> {
124-
reader.readStartArray()
125-
return BsonArrayDecoder(reader, serializersModule, configuration)
126-
}
127-
is PolymorphicKind -> {
128-
reader.readStartDocument()
129-
return PolymorphicDecoder(reader, serializersModule, configuration)
130-
}
137+
return when (descriptor.kind) {
138+
is StructureKind.LIST -> BsonArrayDecoder(descriptor, reader, serializersModule, configuration)
139+
is PolymorphicKind -> PolymorphicDecoder(descriptor, reader, serializersModule, configuration)
131140
is StructureKind.CLASS,
132-
StructureKind.OBJECT -> {
133-
val current = reader.currentBsonType
134-
if (current == null || current == BsonType.DOCUMENT) {
135-
reader.readStartDocument()
136-
}
137-
}
138-
is StructureKind.MAP -> {
139-
reader.readStartDocument()
140-
return BsonDocumentDecoder(reader, serializersModule, configuration)
141-
}
141+
StructureKind.OBJECT -> BsonDocumentDecoder(descriptor, reader, serializersModule, configuration)
142+
is StructureKind.MAP -> MapDecoder(descriptor, reader, serializersModule, configuration)
142143
else -> throw SerializationException("Primitives are not supported at top-level")
143144
}
144-
return DefaultBsonDecoder(reader, serializersModule, configuration)
145145
}
146146

147147
override fun endStructure(descriptor: SerialDescriptor) {
@@ -194,10 +194,17 @@ internal open class DefaultBsonDecoder(
194194

195195
@OptIn(ExperimentalSerializationApi::class)
196196
private class BsonArrayDecoder(
197+
descriptor: SerialDescriptor,
197198
reader: AbstractBsonReader,
198199
serializersModule: SerializersModule,
199200
configuration: BsonConfiguration
200201
) : DefaultBsonDecoder(reader, serializersModule, configuration) {
202+
203+
init {
204+
validateCurrentBsonType(reader, BsonType.ARRAY, descriptor)
205+
reader.readStartArray()
206+
}
207+
201208
private var index = 0
202209
override fun decodeElementIndex(descriptor: SerialDescriptor): Int {
203210
val nextType = reader.readBsonType()
@@ -208,18 +215,46 @@ private class BsonArrayDecoder(
208215

209216
@OptIn(ExperimentalSerializationApi::class)
210217
private class PolymorphicDecoder(
218+
descriptor: SerialDescriptor,
211219
reader: AbstractBsonReader,
212220
serializersModule: SerializersModule,
213221
configuration: BsonConfiguration
214222
) : DefaultBsonDecoder(reader, serializersModule, configuration) {
215223
private var index = 0
224+
private var mark: BsonReaderMark?
216225

217-
override fun <T> decodeSerializableValue(deserializer: DeserializationStrategy<T>): T =
218-
deserializer.deserialize(DefaultBsonDecoder(reader, serializersModule, configuration))
226+
init {
227+
mark = reader.mark
228+
validateCurrentBsonType(reader, BsonType.DOCUMENT, descriptor) { it.serialName }
229+
reader.readStartDocument()
230+
}
231+
232+
override fun <T> decodeSerializableValue(deserializer: DeserializationStrategy<T>): T {
233+
mark?.let {
234+
it.reset()
235+
mark = null
236+
}
237+
return deserializer.deserialize(DefaultBsonDecoder(reader, serializersModule, configuration))
238+
}
219239

220240
override fun decodeElementIndex(descriptor: SerialDescriptor): Int {
241+
var found = false
221242
return when (index) {
222-
0 -> index++
243+
0 -> {
244+
while (reader.readBsonType() != BsonType.END_OF_DOCUMENT) {
245+
if (reader.readName() == configuration.classDiscriminator) {
246+
found = true
247+
break
248+
}
249+
reader.skipValue()
250+
}
251+
if (!found) {
252+
throw SerializationException(
253+
"Missing required discriminator field `${configuration.classDiscriminator}` " +
254+
"for polymorphic class: `${descriptor.serialName}`.")
255+
}
256+
index++
257+
}
223258
1 -> index++
224259
else -> DECODE_DONE
225260
}
@@ -228,6 +263,20 @@ private class PolymorphicDecoder(
228263

229264
@OptIn(ExperimentalSerializationApi::class)
230265
private class BsonDocumentDecoder(
266+
descriptor: SerialDescriptor,
267+
reader: AbstractBsonReader,
268+
serializersModule: SerializersModule,
269+
configuration: BsonConfiguration
270+
) : DefaultBsonDecoder(reader, serializersModule, configuration) {
271+
init {
272+
validateCurrentBsonType(reader, BsonType.DOCUMENT, descriptor) { it.serialName }
273+
reader.readStartDocument()
274+
}
275+
}
276+
277+
@OptIn(ExperimentalSerializationApi::class)
278+
private class MapDecoder(
279+
descriptor: SerialDescriptor,
231280
reader: AbstractBsonReader,
232281
serializersModule: SerializersModule,
233282
configuration: BsonConfiguration
@@ -236,6 +285,11 @@ private class BsonDocumentDecoder(
236285
private var index = 0
237286
private var isKey = false
238287

288+
init {
289+
validateCurrentBsonType(reader, BsonType.DOCUMENT, descriptor)
290+
reader.readStartDocument()
291+
}
292+
239293
override fun decodeString(): String {
240294
return if (isKey) {
241295
reader.readName()

bson-kotlinx/src/test/kotlin/org/bson/codecs/kotlinx/KotlinSerializerCodecProviderTest.kt

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,12 @@ import org.bson.codecs.kotlinx.samples.DataClassOpen
3535
import org.bson.codecs.kotlinx.samples.DataClassOpenA
3636
import org.bson.codecs.kotlinx.samples.DataClassOpenB
3737
import org.bson.codecs.kotlinx.samples.DataClassParameterized
38+
import org.bson.codecs.kotlinx.samples.DataClassSealedInterface
3839
import org.bson.codecs.kotlinx.samples.DataClassWithSimpleValues
40+
import org.bson.codecs.kotlinx.samples.SealedInterface
3941
import org.bson.conversions.Bson
42+
import org.bson.json.JsonReader
43+
import org.bson.types.ObjectId
4044
import org.junit.jupiter.api.Test
4145

4246
class KotlinSerializerCodecProviderTest {
@@ -75,6 +79,41 @@ class KotlinSerializerCodecProviderTest {
7579
assertEquals(DataClassWithSimpleValues::class.java, codec.encoderClass)
7680
}
7781

82+
@Test
83+
fun testDataClassWithSimpleValuesFieldOrdering() {
84+
val codec = MongoClientSettings.getDefaultCodecRegistry().get(DataClassWithSimpleValues::class.java)
85+
val expected = DataClassWithSimpleValues('c', 0, 1, 22, 42L, 4.0f, 4.2, true, "String")
86+
87+
val numberLong = "\$numberLong"
88+
val actual =
89+
codec.decode(
90+
JsonReader(
91+
"""{"boolean": true, "byte": 0, "char": "c", "double": 4.2, "float": 4.0, "int": 22,
92+
|"long": {"$numberLong": "42"}, "short": 1, "string": "String"}"""
93+
.trimMargin()),
94+
DecoderContext.builder().build())
95+
96+
assertEquals(expected, actual)
97+
}
98+
99+
@Test
100+
fun testDataClassSealedFieldOrdering() {
101+
val codec = MongoClientSettings.getDefaultCodecRegistry().get(SealedInterface::class.java)
102+
103+
val objectId = ObjectId("111111111111111111111111")
104+
val oid = "\$oid"
105+
val expected = DataClassSealedInterface(objectId, "string")
106+
val actual =
107+
codec.decode(
108+
JsonReader(
109+
"""{"name": "string", "_id": {$oid: "${objectId.toHexString()}"},
110+
|"_t": "org.bson.codecs.kotlinx.samples.DataClassSealedInterface"}"""
111+
.trimMargin()),
112+
DecoderContext.builder().build())
113+
114+
assertEquals(expected, actual)
115+
}
116+
78117
@OptIn(ExperimentalSerializationApi::class)
79118
@Test
80119
fun shouldAllowOverridingOfSerializersModuleAndBsonConfigurationInConstructor() {

bson-kotlinx/src/test/kotlin/org/bson/codecs/kotlinx/KotlinSerializerCodecTest.kt

Lines changed: 42 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -84,20 +84,23 @@ import org.bson.codecs.kotlinx.samples.DataClassWithSequence
8484
import org.bson.codecs.kotlinx.samples.DataClassWithSimpleValues
8585
import org.bson.codecs.kotlinx.samples.DataClassWithTriple
8686
import org.bson.codecs.kotlinx.samples.Key
87+
import org.bson.codecs.kotlinx.samples.SealedInterface
8788
import org.bson.codecs.kotlinx.samples.ValueClass
8889
import org.junit.jupiter.api.Test
8990
import org.junit.jupiter.api.assertThrows
9091

9192
@OptIn(ExperimentalSerializationApi::class)
93+
@Suppress("LargeClass")
9294
class KotlinSerializerCodecTest {
9395
private val numberLong = "\$numberLong"
96+
private val oid = "\$oid"
9497
private val emptyDocument = "{}"
9598
private val altConfiguration =
9699
BsonConfiguration(encodeDefaults = false, classDiscriminator = "_t", explicitNulls = true)
97100

98101
private val allBsonTypesJson =
99102
"""{
100-
| "id": {"${'$'}oid": "111111111111111111111111"},
103+
| "id": {"$oid": "111111111111111111111111"},
101104
| "arrayEmpty": [],
102105
| "arraySimple": [{"${'$'}numberInt": "1"}, {"${'$'}numberInt": "2"}, {"${'$'}numberInt": "3"}],
103106
| "arrayComplex": [{"a": {"${'$'}numberInt": "1"}}, {"a": {"${'$'}numberInt": "2"}}],
@@ -668,17 +671,49 @@ class KotlinSerializerCodecTest {
668671
codec?.decode(BsonDocumentReader(data), DecoderContext.builder().build())
669672
}
670673

671-
assertThrows<MissingFieldException>("Invalid complex types") {
672-
val data = BsonDocument.parse("""{"_id": "myId", "embedded": 123}""")
673-
val codec = KotlinSerializerCodec.create<DataClassWithEmbedded>()
674-
codec?.decode(BsonDocumentReader(data), DecoderContext.builder().build())
675-
}
676-
677674
assertThrows<IllegalArgumentException>("Failing init") {
678675
val data = BsonDocument.parse("""{"id": "myId"}""")
679676
val codec = KotlinSerializerCodec.create<DataClassWithFailingInit>()
680677
codec?.decode(BsonDocumentReader(data), DecoderContext.builder().build())
681678
}
679+
680+
var exception =
681+
assertThrows<SerializationException>("Invalid complex types - document") {
682+
val data = BsonDocument.parse("""{"_id": "myId", "embedded": 123}""")
683+
val codec = KotlinSerializerCodec.create<DataClassWithEmbedded>()
684+
codec?.decode(BsonDocumentReader(data), DecoderContext.builder().build())
685+
}
686+
assertEquals(
687+
"Invalid data for `org.bson.codecs.kotlinx.samples.DataClassEmbedded` " +
688+
"expected a bson document found: INT32",
689+
exception.message)
690+
691+
exception =
692+
assertThrows<SerializationException>("Invalid complex types - list") {
693+
val data = BsonDocument.parse("""{"_id": "myId", "nested": 123}""")
694+
val codec = KotlinSerializerCodec.create<DataClassListOfDataClasses>()
695+
codec?.decode(BsonDocumentReader(data), DecoderContext.builder().build())
696+
}
697+
assertEquals("Invalid data for `LIST` expected a bson array found: INT32", exception.message)
698+
699+
exception =
700+
assertThrows<SerializationException>("Invalid complex types - map") {
701+
val data = BsonDocument.parse("""{"_id": "myId", "nested": 123}""")
702+
val codec = KotlinSerializerCodec.create<DataClassMapOfDataClasses>()
703+
codec?.decode(BsonDocumentReader(data), DecoderContext.builder().build())
704+
}
705+
assertEquals("Invalid data for `MAP` expected a bson document found: INT32", exception.message)
706+
707+
exception =
708+
assertThrows<SerializationException>("Missing discriminator") {
709+
val data = BsonDocument.parse("""{"_id": {"$oid": "111111111111111111111111"}, "name": "string"}""")
710+
val codec = KotlinSerializerCodec.create<SealedInterface>()
711+
codec?.decode(BsonDocumentReader(data), DecoderContext.builder().build())
712+
}
713+
assertEquals(
714+
"Missing required discriminator field `_t` for polymorphic class: " +
715+
"`org.bson.codecs.kotlinx.samples.SealedInterface`.",
716+
exception.message)
682717
}
683718

684719
@Test

bson-kotlinx/src/test/kotlin/org/bson/codecs/kotlinx/samples/DataClasses.kt

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -245,6 +245,15 @@ data class DataClassOptionalBsonValues(
245245

246246
@Serializable @SerialName("C") data class DataClassSealedC(val c: String) : DataClassSealed()
247247

248+
@Serializable
249+
sealed interface SealedInterface {
250+
val name: String
251+
}
252+
253+
@Serializable
254+
data class DataClassSealedInterface(@Contextual @SerialName("_id") val id: ObjectId, override val name: String) :
255+
SealedInterface
256+
248257
@Serializable data class DataClassListOfSealed(val items: List<DataClassSealed>)
249258

250259
interface DataClassOpen

0 commit comments

Comments
 (0)