diff --git a/bson/src/main/org/bson/codecs/pojo/LazyPropertyModelCodec.java b/bson/src/main/org/bson/codecs/pojo/LazyPropertyModelCodec.java index a502c337bd8..24537ce1d8e 100644 --- a/bson/src/main/org/bson/codecs/pojo/LazyPropertyModelCodec.java +++ b/bson/src/main/org/bson/codecs/pojo/LazyPropertyModelCodec.java @@ -163,19 +163,44 @@ private PropertyModel getSpecializedPropertyModel(final PropertyModel static final class NeedSpecializationCodec extends PojoCodec { private final ClassModel classModel; private final DiscriminatorLookup discriminatorLookup; + private final CodecRegistry codecRegistry; - NeedSpecializationCodec(final ClassModel classModel, final DiscriminatorLookup discriminatorLookup) { + NeedSpecializationCodec(final ClassModel classModel, final DiscriminatorLookup discriminatorLookup, final CodecRegistry codecRegistry) { this.classModel = classModel; this.discriminatorLookup = discriminatorLookup; + this.codecRegistry = codecRegistry; } @Override - public T decode(final BsonReader reader, final DecoderContext decoderContext) { - throw exception(); + public void encode(final BsonWriter writer, final T value, final EncoderContext encoderContext) { + if (value.getClass().equals(classModel.getType())) { + throw exception(); + } + tryEncode(codecRegistry.get(value.getClass()), writer, value, encoderContext); } @Override - public void encode(final BsonWriter writer, final T value, final EncoderContext encoderContext) { + public T decode(final BsonReader reader, final DecoderContext decoderContext) { + return tryDecode(reader, decoderContext); + } + + @SuppressWarnings("unchecked") + private void tryEncode(final Codec codec, final BsonWriter writer, final T value, final EncoderContext encoderContext) { + try { + codec.encode(writer, (A) value, encoderContext); + } catch (Exception e) { + throw exception(); + } + } + + @SuppressWarnings("unchecked") + public T tryDecode(final BsonReader reader, final DecoderContext decoderContext) { + Codec codec = PojoCodecImpl.getCodecFromDocument(reader, classModel.useDiscriminator(), classModel.getDiscriminatorKey(), + codecRegistry, discriminatorLookup, null, classModel.getName()); + if (codec != null) { + return codec.decode(reader, decoderContext); + } + throw exception(); } diff --git a/bson/src/main/org/bson/codecs/pojo/PojoCodecImpl.java b/bson/src/main/org/bson/codecs/pojo/PojoCodecImpl.java index bccadfb3e0d..96853000198 100644 --- a/bson/src/main/org/bson/codecs/pojo/PojoCodecImpl.java +++ b/bson/src/main/org/bson/codecs/pojo/PojoCodecImpl.java @@ -101,7 +101,8 @@ public T decode(final BsonReader reader, final DecoderContext decoderContext) { return instanceCreator.getInstance(); } else { return getCodecFromDocument(reader, classModel.useDiscriminator(), classModel.getDiscriminatorKey(), registry, - discriminatorLookup, this).decode(reader, DecoderContext.builder().checkedDiscriminator(true).build()); + discriminatorLookup, this, classModel.getName()) + .decode(reader, DecoderContext.builder().checkedDiscriminator(true).build()); } } @@ -275,10 +276,11 @@ private boolean areEquivalentTypes(final Class t1, final Class t2) } @SuppressWarnings("unchecked") - private Codec getCodecFromDocument(final BsonReader reader, final boolean useDiscriminator, final String discriminatorKey, - final CodecRegistry registry, final DiscriminatorLookup discriminatorLookup, - final Codec defaultCodec) { - Codec codec = defaultCodec; + @Nullable + static Codec getCodecFromDocument(final BsonReader reader, final boolean useDiscriminator, final String discriminatorKey, + final CodecRegistry registry, final DiscriminatorLookup discriminatorLookup, @Nullable final Codec defaultCodec, + final String simpleClassName) { + Codec codec = defaultCodec; if (useDiscriminator) { BsonReaderMark mark = reader.getMark(); reader.readStartDocument(); @@ -289,12 +291,12 @@ private Codec getCodecFromDocument(final BsonReader reader, final boolean use discriminatorKeyFound = true; try { Class discriminatorClass = discriminatorLookup.lookup(reader.readString()); - if (!codec.getEncoderClass().equals(discriminatorClass)) { - codec = (Codec) registry.get(discriminatorClass); + if (codec == null || !codec.getEncoderClass().equals(discriminatorClass)) { + codec = (Codec) registry.get(discriminatorClass); } } catch (Exception e) { throw new CodecConfigurationException(format("Failed to decode '%s'. Decoding errored with: %s", - classModel.getName(), e.getMessage()), e); + simpleClassName, e.getMessage()), e); } } else { reader.skipValue(); diff --git a/bson/src/main/org/bson/codecs/pojo/PojoCodecProvider.java b/bson/src/main/org/bson/codecs/pojo/PojoCodecProvider.java index 6a3e8bfc836..b62364b1b4b 100644 --- a/bson/src/main/org/bson/codecs/pojo/PojoCodecProvider.java +++ b/bson/src/main/org/bson/codecs/pojo/PojoCodecProvider.java @@ -97,7 +97,7 @@ private static PojoCodec createCodec(final ClassModel classModel, fina final List propertyCodecProviders, final DiscriminatorLookup discriminatorLookup) { return shouldSpecialize(classModel) ? new PojoCodecImpl<>(classModel, codecRegistry, propertyCodecProviders, discriminatorLookup) - : new LazyPropertyModelCodec.NeedSpecializationCodec<>(classModel, discriminatorLookup); + : new LazyPropertyModelCodec.NeedSpecializationCodec<>(classModel, discriminatorLookup, codecRegistry); } /** diff --git a/bson/src/test/unit/org/bson/codecs/pojo/PojoCustomTest.java b/bson/src/test/unit/org/bson/codecs/pojo/PojoCustomTest.java index acb63b04f06..cf8cef50282 100644 --- a/bson/src/test/unit/org/bson/codecs/pojo/PojoCustomTest.java +++ b/bson/src/test/unit/org/bson/codecs/pojo/PojoCustomTest.java @@ -38,11 +38,14 @@ import org.bson.codecs.pojo.entities.BsonRepresentationUnsupportedString; import org.bson.codecs.pojo.entities.ConcreteAndNestedAbstractInterfaceModel; import org.bson.codecs.pojo.entities.ConcreteCollectionsModel; +import org.bson.codecs.pojo.entities.ConcreteModel; +import org.bson.codecs.pojo.entities.ConcreteField; import org.bson.codecs.pojo.entities.ConcreteStandAloneAbstractInterfaceModel; import org.bson.codecs.pojo.entities.ConstructorNotPublicModel; import org.bson.codecs.pojo.entities.ConventionModel; import org.bson.codecs.pojo.entities.ConverterModel; import org.bson.codecs.pojo.entities.CustomPropertyCodecOptionalModel; +import org.bson.codecs.pojo.entities.GenericBaseModel; import org.bson.codecs.pojo.entities.GenericHolderModel; import org.bson.codecs.pojo.entities.GenericTreeModel; import org.bson.codecs.pojo.entities.InterfaceBasedModel; @@ -545,6 +548,17 @@ public void testInvalidDiscriminatorInNestedModel() { + "'simple': {'_t': 'FakeModel', 'integerField': 42, 'stringField': 'myString'}}")); } + @Test + public void testGenericBaseClass() { + CodecRegistry registry = fromProviders(new ValueCodecProvider(), PojoCodecProvider.builder().automatic(true).build()); + + ConcreteModel model = new ConcreteModel(new ConcreteField("name1")); + + String json = "{\"_t\": \"org.bson.codecs.pojo.entities.ConcreteModel\", \"field\": {\"name\": \"name1\"}}"; + roundTrip(PojoCodecProvider.builder().automatic(true), GenericBaseModel.class, model, json); + } + + @Test public void testCannotEncodeUnspecializedClasses() { CodecRegistry registry = fromProviders(getPojoCodecProviderBuilder(GenericTreeModel.class).build()); @@ -553,7 +567,7 @@ public void testCannotEncodeUnspecializedClasses() { } @Test - public void testCannotDecodeUnspecializedClasses() { + public void testCannotDecodeUnspecializedClassesWithoutADiscriminator() { assertThrows(CodecConfigurationException.class, () -> decodingShouldFail(getCodec(GenericTreeModel.class), "{'field1': 'top', 'field2': 1, " diff --git a/bson/src/test/unit/org/bson/codecs/pojo/PojoTestCase.java b/bson/src/test/unit/org/bson/codecs/pojo/PojoTestCase.java index 5b5209435cb..eb380bb7986 100644 --- a/bson/src/test/unit/org/bson/codecs/pojo/PojoTestCase.java +++ b/bson/src/test/unit/org/bson/codecs/pojo/PojoTestCase.java @@ -90,8 +90,12 @@ void roundTrip(final T value, final String json) { } void roundTrip(final PojoCodecProvider.Builder builder, final T value, final String json) { - encodesTo(getCodecRegistry(builder), value, json); - decodesTo(getCodecRegistry(builder), json, value); + roundTrip(builder, value.getClass(), value, json); + } + + void roundTrip(final PojoCodecProvider.Builder builder, final Class clazz, final T value, final String json) { + encodesTo(getCodecRegistry(builder), clazz, value, json); + decodesTo(getCodecRegistry(builder), clazz, json, value); } void threadedRoundTrip(final PojoCodecProvider.Builder builder, final T value, final String json) { @@ -109,21 +113,30 @@ void roundTrip(final CodecRegistry registry, final T value, final String jso decodesTo(registry, json, value); } + void roundTrip(final CodecRegistry registry, final Class clazz, final T value, final String json) { + encodesTo(registry, clazz, value, json); + decodesTo(registry, clazz, json, value); + } + void encodesTo(final PojoCodecProvider.Builder builder, final T value, final String json) { encodesTo(builder, value, json, false); } void encodesTo(final PojoCodecProvider.Builder builder, final T value, final String json, final boolean collectible) { - encodesTo(getCodecRegistry(builder), value, json, collectible); + encodesTo(getCodecRegistry(builder), value.getClass(), value, json, collectible); } void encodesTo(final CodecRegistry registry, final T value, final String json) { - encodesTo(registry, value, json, false); + encodesTo(registry, value.getClass(), value, json, false); + } + + void encodesTo(final CodecRegistry registry, final Class clazz, final T value, final String json) { + encodesTo(registry, clazz, value, json, false); } @SuppressWarnings("unchecked") - void encodesTo(final CodecRegistry registry, final T value, final String json, final boolean collectible) { - Codec codec = (Codec) registry.get(value.getClass()); + void encodesTo(final CodecRegistry registry, final Class clazz, final T value, final String json, final boolean collectible) { + Codec codec = (Codec) registry.get(clazz); encodesTo(codec, value, json, collectible); } @@ -144,7 +157,12 @@ void decodesTo(final PojoCodecProvider.Builder builder, final String json, f @SuppressWarnings("unchecked") void decodesTo(final CodecRegistry registry, final String json, final T expected) { - Codec codec = (Codec) registry.get(expected.getClass()); + decodesTo(registry, expected.getClass(), json, expected); + } + + @SuppressWarnings("unchecked") + void decodesTo(final CodecRegistry registry, final Class clazz, final String json, final T expected) { + Codec codec = (Codec) registry.get(clazz); decodesTo(codec, json, expected); } diff --git a/bson/src/test/unit/org/bson/codecs/pojo/entities/BaseField.java b/bson/src/test/unit/org/bson/codecs/pojo/entities/BaseField.java new file mode 100644 index 00000000000..4393c5f2d7f --- /dev/null +++ b/bson/src/test/unit/org/bson/codecs/pojo/entities/BaseField.java @@ -0,0 +1,55 @@ +/* + * Copyright 2008-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.bson.codecs.pojo.entities; + +import java.util.Objects; + +public abstract class BaseField { + private String name; + + public BaseField(final String name) { + this.name = name; + } + + protected BaseField() { + } + + public String getName() { + return name; + } + + public void setName(final String name) { + this.name = name; + } + + @Override + public boolean equals(final Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + BaseField baseField = (BaseField) o; + return Objects.equals(name, baseField.name); + } + + @Override + public int hashCode() { + return Objects.hashCode(name); + } +} diff --git a/bson/src/test/unit/org/bson/codecs/pojo/entities/ConcreteField.java b/bson/src/test/unit/org/bson/codecs/pojo/entities/ConcreteField.java new file mode 100644 index 00000000000..6fb06a70de9 --- /dev/null +++ b/bson/src/test/unit/org/bson/codecs/pojo/entities/ConcreteField.java @@ -0,0 +1,27 @@ +/* + * Copyright 2008-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.bson.codecs.pojo.entities; + +public class ConcreteField extends BaseField { + + public ConcreteField() { + } + + public ConcreteField(final String name) { + super(name); + } +} diff --git a/bson/src/test/unit/org/bson/codecs/pojo/entities/ConcreteModel.java b/bson/src/test/unit/org/bson/codecs/pojo/entities/ConcreteModel.java new file mode 100644 index 00000000000..cd406fa1392 --- /dev/null +++ b/bson/src/test/unit/org/bson/codecs/pojo/entities/ConcreteModel.java @@ -0,0 +1,27 @@ +/* + * Copyright 2008-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.bson.codecs.pojo.entities; + +public class ConcreteModel extends GenericBaseModel { + + public ConcreteModel() { + } + + public ConcreteModel(final ConcreteField field) { + super(field); + } +} diff --git a/bson/src/test/unit/org/bson/codecs/pojo/entities/GenericBaseModel.java b/bson/src/test/unit/org/bson/codecs/pojo/entities/GenericBaseModel.java new file mode 100644 index 00000000000..5164f9703e5 --- /dev/null +++ b/bson/src/test/unit/org/bson/codecs/pojo/entities/GenericBaseModel.java @@ -0,0 +1,59 @@ +/* + * Copyright 2008-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.bson.codecs.pojo.entities; + +import org.bson.codecs.pojo.annotations.BsonDiscriminator; + +import java.util.Objects; + +@BsonDiscriminator() +public class GenericBaseModel { + + private T field; + + public GenericBaseModel(final T field) { + this.field = field; + } + + public GenericBaseModel() { + } + + public T getField() { + return field; + } + + public void setField(final T field) { + this.field = field; + } + + @Override + public boolean equals(final Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + GenericBaseModel that = (GenericBaseModel) o; + return Objects.equals(field, that.field); + } + + @Override + public int hashCode() { + return Objects.hashCode(field); + } +}