diff --git a/driver-core/src/main/com/mongodb/internal/connection/IdHoldingBsonWriter.java b/driver-core/src/main/com/mongodb/internal/connection/IdHoldingBsonWriter.java index 606458b3382..4120dbdfb17 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/IdHoldingBsonWriter.java +++ b/driver-core/src/main/com/mongodb/internal/connection/IdHoldingBsonWriter.java @@ -16,6 +16,7 @@ package com.mongodb.internal.connection; +import com.mongodb.lang.Nullable; import org.bson.BsonBinary; import org.bson.BsonBinaryWriter; import org.bson.BsonBoolean; @@ -57,11 +58,17 @@ public class IdHoldingBsonWriter extends LevelCountingBsonWriter { private LevelCountingBsonWriter idBsonBinaryWriter; private BasicOutputBuffer outputBuffer; private String currentFieldName; + private final BsonValue fallbackId; private BsonValue id; private boolean idFieldIsAnArray = false; - public IdHoldingBsonWriter(final BsonWriter bsonWriter) { + /** + * @param fallbackId The "_id" field value to use if the top-level document written via this {@link BsonWriter} + * does not have "_id". If {@code null}, then a new {@link BsonObjectId} is generated instead. + */ + public IdHoldingBsonWriter(final BsonWriter bsonWriter, @Nullable final BsonObjectId fallbackId) { super(bsonWriter); + this.fallbackId = fallbackId; } @Override @@ -99,7 +106,7 @@ public void writeEndDocument() { } if (getCurrentLevel() == 0 && id == null) { - id = new BsonObjectId(); + id = fallbackId == null ? new BsonObjectId() : fallbackId; writeObjectId(ID_FIELD_NAME, id.asObjectId().getValue()); } super.writeEndDocument(); @@ -408,6 +415,15 @@ public void flush() { super.flush(); } + /** + * Returns either the value of the "_id" field from the top-level document written via this {@link BsonWriter}, + * provided that the document is not {@link RawBsonDocument}, + * or the generated {@link BsonObjectId}. + * If the document is {@link RawBsonDocument}, then returns {@code null}. + *

+ * {@linkplain #flush() Flushing} is not required before calling this method.

+ */ + @Nullable public BsonValue getId() { return id; } diff --git a/driver-core/src/main/com/mongodb/internal/connection/SplittablePayload.java b/driver-core/src/main/com/mongodb/internal/connection/SplittablePayload.java index a71f7a940f0..7a0b835f428 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/SplittablePayload.java +++ b/driver-core/src/main/com/mongodb/internal/connection/SplittablePayload.java @@ -23,6 +23,7 @@ import com.mongodb.internal.bulk.WriteRequestWithIndex; import org.bson.BsonDocument; import org.bson.BsonDocumentWrapper; +import org.bson.BsonObjectId; import org.bson.BsonValue; import org.bson.BsonWriter; import org.bson.codecs.BsonValueCodecProvider; @@ -191,10 +192,23 @@ public void encode(final BsonWriter writer, final WriteRequestWithIndex writeReq InsertRequest insertRequest = (InsertRequest) writeRequestWithIndex.getWriteRequest(); BsonDocument document = insertRequest.getDocument(); - IdHoldingBsonWriter idHoldingBsonWriter = new IdHoldingBsonWriter(writer); - getCodec(document).encode(idHoldingBsonWriter, document, - EncoderContext.builder().isEncodingCollectibleDocument(true).build()); - insertedIds.put(writeRequestWithIndex.getIndex(), idHoldingBsonWriter.getId()); + BsonValue documentId = insertedIds.compute( + writeRequestWithIndex.getIndex(), + (writeRequestIndex, writeRequestDocumentId) -> { + IdHoldingBsonWriter idHoldingBsonWriter = new IdHoldingBsonWriter( + writer, + // Reuse `writeRequestDocumentId` if it may have been generated + // by `IdHoldingBsonWriter` in a previous attempt. + // If its type is not `BsonObjectId`, we know it could not have been generated. + writeRequestDocumentId instanceof BsonObjectId ? writeRequestDocumentId.asObjectId() : null); + getCodec(document).encode(idHoldingBsonWriter, document, + EncoderContext.builder().isEncodingCollectibleDocument(true).build()); + return idHoldingBsonWriter.getId(); + }); + if (documentId == null) { + // we must add an entry anyway because we rely on all the indexes being present + insertedIds.put(writeRequestWithIndex.getIndex(), null); + } } else if (writeRequestWithIndex.getType() == WriteRequest.Type.UPDATE || writeRequestWithIndex.getType() == WriteRequest.Type.REPLACE) { UpdateRequest update = (UpdateRequest) writeRequestWithIndex.getWriteRequest(); diff --git a/driver-core/src/test/unit/com/mongodb/internal/connection/IdHoldingBsonWriterSpecification.groovy b/driver-core/src/test/unit/com/mongodb/internal/connection/IdHoldingBsonWriterSpecification.groovy index 451545632d4..f603576ecfb 100644 --- a/driver-core/src/test/unit/com/mongodb/internal/connection/IdHoldingBsonWriterSpecification.groovy +++ b/driver-core/src/test/unit/com/mongodb/internal/connection/IdHoldingBsonWriterSpecification.groovy @@ -32,11 +32,12 @@ import static org.bson.BsonHelper.documentWithValuesOfEveryType import static org.bson.BsonHelper.getBsonValues class IdHoldingBsonWriterSpecification extends Specification { + private static final OBJECT_ID = new BsonObjectId() def 'should write all types'() { given: def bsonBinaryWriter = new BsonBinaryWriter(new BasicOutputBuffer()) - def idTrackingBsonWriter = new IdHoldingBsonWriter(bsonBinaryWriter) + def idTrackingBsonWriter = new IdHoldingBsonWriter(bsonBinaryWriter, fallbackId) def document = documentWithValuesOfEveryType() when: @@ -47,18 +48,25 @@ class IdHoldingBsonWriterSpecification extends Specification { !document.containsKey('_id') encodedDocument.containsKey('_id') idTrackingBsonWriter.getId() == encodedDocument.get('_id') + if (expectedIdNullIfMustBeGenerated != null) { + idTrackingBsonWriter.getId() == expectedIdNullIfMustBeGenerated + } when: encodedDocument.remove('_id') then: encodedDocument == document + + where: + fallbackId << [null, OBJECT_ID] + expectedIdNullIfMustBeGenerated << [null, OBJECT_ID] } def 'should support all types for _id value'() { given: def bsonBinaryWriter = new BsonBinaryWriter(new BasicOutputBuffer()) - def idTrackingBsonWriter = new IdHoldingBsonWriter(bsonBinaryWriter) + def idTrackingBsonWriter = new IdHoldingBsonWriter(bsonBinaryWriter, fallbackId) def document = new BsonDocument() document.put('_id', id) @@ -71,12 +79,15 @@ class IdHoldingBsonWriterSpecification extends Specification { idTrackingBsonWriter.getId() == id where: - id << getBsonValues() + [id, fallbackId] << [ + getBsonValues(), + [null, new BsonObjectId()] + ].combinations() } def 'serialize document with list of documents that contain an _id field'() { def bsonBinaryWriter = new BsonBinaryWriter(new BasicOutputBuffer()) - def idTrackingBsonWriter = new IdHoldingBsonWriter(bsonBinaryWriter) + def idTrackingBsonWriter = new IdHoldingBsonWriter(bsonBinaryWriter, fallbackId) def document = new BsonDocument('_id', new BsonObjectId()) .append('items', new BsonArray(Collections.singletonList(new BsonDocument('_id', new BsonObjectId())))) @@ -86,11 +97,14 @@ class IdHoldingBsonWriterSpecification extends Specification { then: encodedDocument == document + + where: + fallbackId << [null, new BsonObjectId()] } def 'serialize _id documents containing arrays'() { def bsonBinaryWriter = new BsonBinaryWriter(new BasicOutputBuffer()) - def idTrackingBsonWriter = new IdHoldingBsonWriter(bsonBinaryWriter) + def idTrackingBsonWriter = new IdHoldingBsonWriter(bsonBinaryWriter, fallbackId) BsonDocument document = BsonDocument.parse(json) when: @@ -102,7 +116,8 @@ class IdHoldingBsonWriterSpecification extends Specification { encodedDocument == document where: - json << ['{"_id": {"a": []}, "b": 123}', + [json, fallbackId] << [ + ['{"_id": {"a": []}, "b": 123}', '{"_id": {"a": [1, 2]}, "b": 123}', '{"_id": {"a": [[[[1]]]]}, "b": 123}', '{"_id": {"a": [{"a": [1, 2]}]}, "b": 123}', @@ -112,7 +127,9 @@ class IdHoldingBsonWriterSpecification extends Specification { '{"_id": [1, 2], "b": 123}', '{"_id": [[1], [[2]]], "b": 123}', '{"_id": [{"a": 1}], "b": 123}', - '{"_id": [{"a": [{"b": 123}]}]}'] + '{"_id": [{"a": [{"b": 123}]}]}'], + [null, new BsonObjectId()] + ].combinations() } private static BsonDocument getEncodedDocument(BsonOutput buffer) { diff --git a/driver-sync/src/test/functional/com/mongodb/client/CrudProseTest.java b/driver-sync/src/test/functional/com/mongodb/client/CrudProseTest.java index b8d94cfe067..5d3907bb210 100644 --- a/driver-sync/src/test/functional/com/mongodb/client/CrudProseTest.java +++ b/driver-sync/src/test/functional/com/mongodb/client/CrudProseTest.java @@ -19,21 +19,35 @@ import com.mongodb.MongoBulkWriteException; import com.mongodb.MongoWriteConcernException; import com.mongodb.MongoWriteException; +import com.mongodb.ServerAddress; import com.mongodb.client.model.CreateCollectionOptions; import com.mongodb.client.model.Filters; import com.mongodb.client.model.ValidationOptions; +import com.mongodb.event.CommandListener; +import com.mongodb.event.CommandStartedEvent; import org.bson.BsonArray; import org.bson.BsonDocument; import org.bson.BsonInt32; import org.bson.BsonString; +import org.bson.BsonValue; import org.bson.Document; +import org.bson.codecs.pojo.PojoCodecProvider; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; + import static com.mongodb.ClusterFixture.isDiscoverableReplicaSet; import static com.mongodb.ClusterFixture.serverVersionAtLeast; +import static com.mongodb.MongoClientSettings.getDefaultCodecRegistry; +import static com.mongodb.client.Fixture.getMongoClientSettingsBuilder; import static java.lang.String.format; import static java.util.Arrays.asList; +import static java.util.Collections.singletonList; +import static org.bson.codecs.configuration.CodecRegistries.fromProviders; +import static org.bson.codecs.configuration.CodecRegistries.fromRegistries; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertNotNull; @@ -114,6 +128,54 @@ public void testWriteErrorDetailsIsPropagated() { } } + /** + * This test is not from the specification. + */ + @Test + @SuppressWarnings("try") + void insertMustGenerateIdAtMostOnce() throws ExecutionException, InterruptedException { + assumeTrue(isDiscoverableReplicaSet()); + ServerAddress primaryServerAddress = Fixture.getPrimary(); + CompletableFuture futureIdGeneratedByFirstInsertAttempt = new CompletableFuture<>(); + CompletableFuture futureIdGeneratedBySecondInsertAttempt = new CompletableFuture<>(); + CommandListener commandListener = new CommandListener() { + @Override + public void commandStarted(final CommandStartedEvent event) { + if (event.getCommandName().equals("insert")) { + BsonValue generatedId = event.getCommand().getArray("documents").get(0).asDocument().get("_id"); + if (!futureIdGeneratedByFirstInsertAttempt.isDone()) { + futureIdGeneratedByFirstInsertAttempt.complete(generatedId); + } else { + futureIdGeneratedBySecondInsertAttempt.complete(generatedId); + } + } + } + }; + BsonDocument failPointDocument = new BsonDocument("configureFailPoint", new BsonString("failCommand")) + .append("mode", new BsonDocument("times", new BsonInt32(1))) + .append("data", new BsonDocument() + .append("failCommands", new BsonArray(singletonList(new BsonString("insert")))) + .append("errorLabels", new BsonArray(singletonList(new BsonString("RetryableWriteError")))) + .append("writeConcernError", new BsonDocument("code", new BsonInt32(91)) + .append("errmsg", new BsonString("Replication is being shut down")))); + try (MongoClient client = MongoClients.create(getMongoClientSettingsBuilder() + .retryWrites(true) + .addCommandListener(commandListener) + .applyToServerSettings(builder -> builder.heartbeatFrequency(50, TimeUnit.MILLISECONDS)) + .build()); + FailPoint ignored = FailPoint.enable(failPointDocument, primaryServerAddress)) { + MongoCollection coll = client.getDatabase(database.getName()) + .getCollection(collection.getNamespace().getCollectionName(), MyDocument.class) + .withCodecRegistry(fromRegistries( + getDefaultCodecRegistry(), + fromProviders(PojoCodecProvider.builder().automatic(true).build()))); + BsonValue insertedId = coll.insertOne(new MyDocument()).getInsertedId(); + BsonValue idGeneratedByFirstInsertAttempt = futureIdGeneratedByFirstInsertAttempt.get(); + assertEquals(idGeneratedByFirstInsertAttempt, insertedId); + assertEquals(idGeneratedByFirstInsertAttempt, futureIdGeneratedBySecondInsertAttempt.get()); + } + } + private void setFailPoint() { failPointDocument = new BsonDocument("configureFailPoint", new BsonString("failCommand")) .append("mode", new BsonDocument("times", new BsonInt32(1))) @@ -130,4 +192,15 @@ private void setFailPoint() { private void disableFailPoint() { getCollectionHelper().runAdminCommand(failPointDocument.append("mode", new BsonString("off"))); } + + public static final class MyDocument { + private int v; + + public MyDocument() { + } + + public int getV() { + return v; + } + } }