38
38
import java .lang .reflect .InvocationTargetException ;
39
39
import java .lang .reflect .ParameterizedType ;
40
40
import java .lang .reflect .RecordComponent ;
41
+ import java .lang .reflect .Type ;
42
+ import java .lang .reflect .TypeVariable ;
41
43
import java .util .ArrayList ;
42
44
import java .util .Arrays ;
43
45
import java .util .List ;
48
50
import static java .lang .String .format ;
49
51
import static org .bson .assertions .Assertions .notNull ;
50
52
51
- final class RecordCodec <T extends Record > implements Codec <T > {
53
+ final class RecordCodec <T extends Record > implements Codec <T >, Parameterizable {
52
54
private static final Logger LOGGER = Loggers .getLogger ("RecordCodec" );
53
55
private final Class <T > clazz ;
56
+ private final boolean requiresParameterization ;
54
57
private final Constructor <?> canonicalConstructor ;
55
58
private final List <ComponentModel > componentModels ;
56
59
private final ComponentModel componentModelForId ;
@@ -62,10 +65,11 @@ private static final class ComponentModel {
62
65
private final int index ;
63
66
private final String fieldName ;
64
67
65
- private ComponentModel (final RecordComponent component , final CodecRegistry codecRegistry , final int index ) {
68
+ private ComponentModel (final List <Type > typeParameters , final RecordComponent component , final CodecRegistry codecRegistry ,
69
+ final int index ) {
66
70
validateAnnotations (component , index );
67
71
this .component = component ;
68
- this .codec = computeCodec (component , codecRegistry );
72
+ this .codec = computeCodec (typeParameters , component , codecRegistry );
69
73
this .index = index ;
70
74
this .fieldName = computeFieldName (component );
71
75
}
@@ -83,11 +87,13 @@ Object getValue(final Record record) throws InvocationTargetException, IllegalAc
83
87
}
84
88
85
89
@ SuppressWarnings ("deprecation" )
86
- private static Codec <?> computeCodec (final RecordComponent component , final CodecRegistry codecRegistry ) {
87
- var codec = codecRegistry .get (toWrapper (component .getType ()));
90
+ private static Codec <?> computeCodec (final List <Type > typeParameters , final RecordComponent component ,
91
+ final CodecRegistry codecRegistry ) {
92
+ var codec = codecRegistry .get (toWrapper (resolveComponentType (typeParameters , component )));
88
93
if (codec instanceof Parameterizable parameterizableCodec
89
94
&& component .getGenericType () instanceof ParameterizedType parameterizedType ) {
90
- codec = parameterizableCodec .parameterize (codecRegistry , Arrays .asList (parameterizedType .getActualTypeArguments ()));
95
+ codec = parameterizableCodec .parameterize (codecRegistry ,
96
+ resolveActualTypeArguments (typeParameters , component .getDeclaringRecord (), parameterizedType ));
91
97
}
92
98
BsonType bsonRepresentationType = null ;
93
99
@@ -109,6 +115,36 @@ private static Codec<?> computeCodec(final RecordComponent component, final Code
109
115
return codec ;
110
116
}
111
117
118
+ private static Class <?> resolveComponentType (final List <Type > typeParameters , final RecordComponent component ) {
119
+ Type resolvedType = resolveType (component .getGenericType (), typeParameters , component .getDeclaringRecord ());
120
+ return resolvedType instanceof Class <?> clazz ? clazz : component .getType ();
121
+ }
122
+
123
+ private static List <Type > resolveActualTypeArguments (final List <Type > typeParameters , final Class <?> recordClass ,
124
+ final ParameterizedType parameterizedType ) {
125
+ return Arrays .stream (parameterizedType .getActualTypeArguments ())
126
+ .map (type -> resolveType (type , typeParameters , recordClass ))
127
+ .toList ();
128
+ }
129
+
130
+ private static Type resolveType (final Type type , final List <Type > typeParameters , final Class <?> recordClass ) {
131
+ return type instanceof TypeVariable <?> typeVariable
132
+ ? typeParameters .get (getIndexOfTypeParameter (typeVariable .getName (), recordClass ))
133
+ : type ;
134
+ }
135
+
136
+ // Get
137
+ private static int getIndexOfTypeParameter (final String typeParameterName , final Class <?> recordClass ) {
138
+ var typeParameters = recordClass .getTypeParameters ();
139
+ for (int i = 0 ; i < typeParameters .length ; i ++) {
140
+ if (typeParameters [i ].getName ().equals (typeParameterName )) {
141
+ return i ;
142
+ }
143
+ }
144
+ throw new CodecConfigurationException (String .format ("Could not find type parameter on record %s with name %s" ,
145
+ recordClass .getName (), typeParameterName ));
146
+ }
147
+
112
148
@ SuppressWarnings ("deprecation" )
113
149
private static String computeFieldName (final RecordComponent component ) {
114
150
if (component .isAnnotationPresent (BsonId .class )) {
@@ -218,16 +254,47 @@ private static <T extends Annotation> void validateAnnotationOnlyOnField(final R
218
254
219
255
RecordCodec (final Class <T > clazz , final CodecRegistry codecRegistry ) {
220
256
this .clazz = notNull ("class" , clazz );
257
+ if (clazz .getTypeParameters ().length > 0 ) {
258
+ requiresParameterization = true ;
259
+ canonicalConstructor = null ;
260
+ componentModels = null ;
261
+ fieldNameToComponentModel = null ;
262
+ componentModelForId = null ;
263
+ } else {
264
+ requiresParameterization = false ;
265
+ canonicalConstructor = notNull ("canonicalConstructor" , getCanonicalConstructor (clazz ));
266
+ componentModels = getComponentModels (clazz , codecRegistry , List .of ());
267
+ fieldNameToComponentModel = componentModels .stream ()
268
+ .collect (Collectors .toMap (ComponentModel ::getFieldName , Function .identity ()));
269
+ componentModelForId = getComponentModelForId (clazz , componentModels );
270
+ }
271
+ }
272
+
273
+ RecordCodec (final Class <T > clazz , final CodecRegistry codecRegistry , final List <Type > types ) {
274
+ if (types .size () != clazz .getTypeParameters ().length ) {
275
+ throw new CodecConfigurationException ("Unexpected number of type parameters for record class " + clazz );
276
+ }
277
+ this .clazz = notNull ("class" , clazz );
278
+ requiresParameterization = false ;
221
279
canonicalConstructor = notNull ("canonicalConstructor" , getCanonicalConstructor (clazz ));
222
- componentModels = getComponentModels (clazz , codecRegistry );
280
+ componentModels = getComponentModels (clazz , codecRegistry , types );
223
281
fieldNameToComponentModel = componentModels .stream ()
224
282
.collect (Collectors .toMap (ComponentModel ::getFieldName , Function .identity ()));
225
283
componentModelForId = getComponentModelForId (clazz , componentModels );
226
284
}
227
285
286
+ @ Override
287
+ public Codec <?> parameterize (final CodecRegistry codecRegistry , final List <Type > types ) {
288
+ return new RecordCodec <>(clazz , codecRegistry , types );
289
+ }
290
+
228
291
@ SuppressWarnings ("unchecked" )
229
292
@ Override
230
293
public T decode (final BsonReader reader , final DecoderContext decoderContext ) {
294
+ if (requiresParameterization ) {
295
+ throw new CodecConfigurationException ("Can not decode to a record with type parameters that has not been parameterized" );
296
+ }
297
+
231
298
reader .readStartDocument ();
232
299
233
300
Object [] constructorArguments = new Object [componentModels .size ()];
@@ -254,6 +321,10 @@ public T decode(final BsonReader reader, final DecoderContext decoderContext) {
254
321
255
322
@ Override
256
323
public void encode (final BsonWriter writer , final T record , final EncoderContext encoderContext ) {
324
+ if (requiresParameterization ) {
325
+ throw new CodecConfigurationException ("Can not decode to a record with type parameters that has not been parameterized" );
326
+ }
327
+
257
328
writer .writeStartDocument ();
258
329
if (componentModelForId != null ) {
259
330
writeComponent (writer , record , componentModelForId );
@@ -287,11 +358,12 @@ private void writeComponent(final BsonWriter writer, final T record, final Compo
287
358
}
288
359
}
289
360
290
- private static <T > List <ComponentModel > getComponentModels (final Class <T > clazz , final CodecRegistry codecRegistry ) {
361
+ private static <T > List <ComponentModel > getComponentModels (final Class <T > clazz , final CodecRegistry codecRegistry ,
362
+ final List <Type > typeParameters ) {
291
363
var recordComponents = clazz .getRecordComponents ();
292
364
var componentModels = new ArrayList <ComponentModel >(recordComponents .length );
293
365
for (int i = 0 ; i < recordComponents .length ; i ++) {
294
- componentModels .add (new ComponentModel (recordComponents [i ], codecRegistry , i ));
366
+ componentModels .add (new ComponentModel (typeParameters , recordComponents [i ], codecRegistry , i ));
295
367
}
296
368
return componentModels ;
297
369
}
0 commit comments