@@ -31,6 +31,7 @@ import kotlinx.serialization.modules.SerializersModule
31
31
import org.bson.AbstractBsonReader
32
32
import org.bson.BsonInvalidOperationException
33
33
import org.bson.BsonReader
34
+ import org.bson.BsonReaderMark
34
35
import org.bson.BsonType
35
36
import org.bson.BsonValue
36
37
import org.bson.codecs.BsonValueCodec
@@ -68,6 +69,20 @@ internal open class DefaultBsonDecoder(
68
69
val validKeyKinds = setOf (PrimitiveKind .STRING , PrimitiveKind .CHAR , SerialKind .ENUM )
69
70
val bsonValueCodec = BsonValueCodec ()
70
71
const val UNKNOWN_INDEX = - 10
72
+ fun validateCurrentBsonType (
73
+ reader : AbstractBsonReader ,
74
+ expectedType : BsonType ,
75
+ descriptor : SerialDescriptor ,
76
+ actualType : (descriptor: SerialDescriptor ) -> String = { it.kind.toString() }
77
+ ) {
78
+ reader.currentBsonType?.let {
79
+ if (it != expectedType) {
80
+ throw SerializationException (
81
+ " Invalid data for `${actualType(descriptor)} ` expected a bson " +
82
+ " ${expectedType.name.lowercase()} found: ${reader.currentBsonType} " )
83
+ }
84
+ }
85
+ }
71
86
}
72
87
73
88
private fun initElementMetadata (descriptor : SerialDescriptor ) {
@@ -119,29 +134,14 @@ internal open class DefaultBsonDecoder(
119
134
120
135
@Suppress(" ReturnCount" )
121
136
override fun beginStructure (descriptor : SerialDescriptor ): CompositeDecoder {
122
- when (descriptor.kind) {
123
- is StructureKind .LIST -> {
124
- reader.readStartArray()
125
- return BsonArrayDecoder (reader, serializersModule, configuration)
126
- }
127
- is PolymorphicKind -> {
128
- reader.readStartDocument()
129
- return PolymorphicDecoder (reader, serializersModule, configuration)
130
- }
137
+ return when (descriptor.kind) {
138
+ is StructureKind .LIST -> BsonArrayDecoder (descriptor, reader, serializersModule, configuration)
139
+ is PolymorphicKind -> PolymorphicDecoder (descriptor, reader, serializersModule, configuration)
131
140
is StructureKind .CLASS ,
132
- StructureKind .OBJECT -> {
133
- val current = reader.currentBsonType
134
- if (current == null || current == BsonType .DOCUMENT ) {
135
- reader.readStartDocument()
136
- }
137
- }
138
- is StructureKind .MAP -> {
139
- reader.readStartDocument()
140
- return BsonDocumentDecoder (reader, serializersModule, configuration)
141
- }
141
+ StructureKind .OBJECT -> BsonDocumentDecoder (descriptor, reader, serializersModule, configuration)
142
+ is StructureKind .MAP -> MapDecoder (descriptor, reader, serializersModule, configuration)
142
143
else -> throw SerializationException (" Primitives are not supported at top-level" )
143
144
}
144
- return DefaultBsonDecoder (reader, serializersModule, configuration)
145
145
}
146
146
147
147
override fun endStructure (descriptor : SerialDescriptor ) {
@@ -194,10 +194,17 @@ internal open class DefaultBsonDecoder(
194
194
195
195
@OptIn(ExperimentalSerializationApi ::class )
196
196
private class BsonArrayDecoder (
197
+ descriptor : SerialDescriptor ,
197
198
reader : AbstractBsonReader ,
198
199
serializersModule : SerializersModule ,
199
200
configuration : BsonConfiguration
200
201
) : DefaultBsonDecoder(reader, serializersModule, configuration) {
202
+
203
+ init {
204
+ validateCurrentBsonType(reader, BsonType .ARRAY , descriptor)
205
+ reader.readStartArray()
206
+ }
207
+
201
208
private var index = 0
202
209
override fun decodeElementIndex (descriptor : SerialDescriptor ): Int {
203
210
val nextType = reader.readBsonType()
@@ -208,18 +215,46 @@ private class BsonArrayDecoder(
208
215
209
216
@OptIn(ExperimentalSerializationApi ::class )
210
217
private class PolymorphicDecoder (
218
+ descriptor : SerialDescriptor ,
211
219
reader : AbstractBsonReader ,
212
220
serializersModule : SerializersModule ,
213
221
configuration : BsonConfiguration
214
222
) : DefaultBsonDecoder(reader, serializersModule, configuration) {
215
223
private var index = 0
224
+ private var mark: BsonReaderMark ?
216
225
217
- override fun <T > decodeSerializableValue (deserializer : DeserializationStrategy <T >): T =
218
- deserializer.deserialize(DefaultBsonDecoder (reader, serializersModule, configuration))
226
+ init {
227
+ mark = reader.mark
228
+ validateCurrentBsonType(reader, BsonType .DOCUMENT , descriptor) { it.serialName }
229
+ reader.readStartDocument()
230
+ }
231
+
232
+ override fun <T > decodeSerializableValue (deserializer : DeserializationStrategy <T >): T {
233
+ mark?.let {
234
+ it.reset()
235
+ mark = null
236
+ }
237
+ return deserializer.deserialize(DefaultBsonDecoder (reader, serializersModule, configuration))
238
+ }
219
239
220
240
override fun decodeElementIndex (descriptor : SerialDescriptor ): Int {
241
+ var found = false
221
242
return when (index) {
222
- 0 -> index++
243
+ 0 -> {
244
+ while (reader.readBsonType() != BsonType .END_OF_DOCUMENT ) {
245
+ if (reader.readName() == configuration.classDiscriminator) {
246
+ found = true
247
+ break
248
+ }
249
+ reader.skipValue()
250
+ }
251
+ if (! found) {
252
+ throw SerializationException (
253
+ " Missing required discriminator field `${configuration.classDiscriminator} ` " +
254
+ " for polymorphic class: `${descriptor.serialName} `." )
255
+ }
256
+ index++
257
+ }
223
258
1 -> index++
224
259
else -> DECODE_DONE
225
260
}
@@ -228,6 +263,20 @@ private class PolymorphicDecoder(
228
263
229
264
@OptIn(ExperimentalSerializationApi ::class )
230
265
private class BsonDocumentDecoder (
266
+ descriptor : SerialDescriptor ,
267
+ reader : AbstractBsonReader ,
268
+ serializersModule : SerializersModule ,
269
+ configuration : BsonConfiguration
270
+ ) : DefaultBsonDecoder(reader, serializersModule, configuration) {
271
+ init {
272
+ validateCurrentBsonType(reader, BsonType .DOCUMENT , descriptor) { it.serialName }
273
+ reader.readStartDocument()
274
+ }
275
+ }
276
+
277
+ @OptIn(ExperimentalSerializationApi ::class )
278
+ private class MapDecoder (
279
+ descriptor : SerialDescriptor ,
231
280
reader : AbstractBsonReader ,
232
281
serializersModule : SerializersModule ,
233
282
configuration : BsonConfiguration
@@ -236,6 +285,11 @@ private class BsonDocumentDecoder(
236
285
private var index = 0
237
286
private var isKey = false
238
287
288
+ init {
289
+ validateCurrentBsonType(reader, BsonType .DOCUMENT , descriptor)
290
+ reader.readStartDocument()
291
+ }
292
+
239
293
override fun decodeString (): String {
240
294
return if (isKey) {
241
295
reader.readName()
0 commit comments