From a3c74be65c9bda625cf88e384e4706d55b541f8c Mon Sep 17 00:00:00 2001
From: Valentin Kovalenko <valentin.kovalenko@mongodb.com>
Date: Wed, 14 Aug 2024 21:23:02 -0600
Subject: [PATCH 1/4] MixedBulkWriteOperation should generate inserted document
 IDs at most once per batch

JAVA-5572
---
 .../connection/IdHoldingBsonWriter.java       | 20 ++++++++++--
 .../connection/SplittablePayload.java         | 22 ++++++++++---
 .../IdHoldingBsonWriterSpecification.groovy   | 31 ++++++++++++++-----
 3 files changed, 60 insertions(+), 13 deletions(-)

diff --git a/driver-core/src/main/com/mongodb/internal/connection/IdHoldingBsonWriter.java b/driver-core/src/main/com/mongodb/internal/connection/IdHoldingBsonWriter.java
index 606458b3382..4120dbdfb17 100644
--- a/driver-core/src/main/com/mongodb/internal/connection/IdHoldingBsonWriter.java
+++ b/driver-core/src/main/com/mongodb/internal/connection/IdHoldingBsonWriter.java
@@ -16,6 +16,7 @@
 
 package com.mongodb.internal.connection;
 
+import com.mongodb.lang.Nullable;
 import org.bson.BsonBinary;
 import org.bson.BsonBinaryWriter;
 import org.bson.BsonBoolean;
@@ -57,11 +58,17 @@ public class IdHoldingBsonWriter extends LevelCountingBsonWriter {
     private LevelCountingBsonWriter idBsonBinaryWriter;
     private BasicOutputBuffer outputBuffer;
     private String currentFieldName;
+    private final BsonObjectId fallbackId;
     private BsonValue id;
     private boolean idFieldIsAnArray = false;
 
-    public IdHoldingBsonWriter(final BsonWriter bsonWriter) {
+    /**
+     * @param fallbackId The "_id" field value to use if the top-level document written via this {@link BsonWriter}
+     * does not have "_id". If {@code null}, then a new {@link BsonObjectId} is generated instead.
+     */
+    public IdHoldingBsonWriter(final BsonWriter bsonWriter, @Nullable final BsonObjectId fallbackId) {
         super(bsonWriter);
+        this.fallbackId = fallbackId;
     }
 
     @Override
@@ -99,7 +106,7 @@ public void writeEndDocument() {
         }
 
         if (getCurrentLevel() == 0 && id == null) {
-            id = new BsonObjectId();
+            id = fallbackId == null ? new BsonObjectId() : fallbackId;
             writeObjectId(ID_FIELD_NAME, id.asObjectId().getValue());
         }
         super.writeEndDocument();
@@ -408,6 +415,15 @@ public void flush() {
         super.flush();
     }
 
+    /**
+     * Returns either the value of the "_id" field from the top-level document written via this {@link BsonWriter},
+     * provided that the document is not {@link RawBsonDocument},
+     * or the generated {@link BsonObjectId}.
+     * If the document is {@link RawBsonDocument}, then returns {@code null}.
+     * <p>
+     * {@linkplain #flush() Flushing} is not required before calling this method.</p>
+     */
+    @Nullable
     public BsonValue getId() {
         return id;
     }
diff --git a/driver-core/src/main/com/mongodb/internal/connection/SplittablePayload.java b/driver-core/src/main/com/mongodb/internal/connection/SplittablePayload.java
index a71f7a940f0..7a0b835f428 100644
--- a/driver-core/src/main/com/mongodb/internal/connection/SplittablePayload.java
+++ b/driver-core/src/main/com/mongodb/internal/connection/SplittablePayload.java
@@ -23,6 +23,7 @@
 import com.mongodb.internal.bulk.WriteRequestWithIndex;
 import org.bson.BsonDocument;
 import org.bson.BsonDocumentWrapper;
+import org.bson.BsonObjectId;
 import org.bson.BsonValue;
 import org.bson.BsonWriter;
 import org.bson.codecs.BsonValueCodecProvider;
@@ -191,10 +192,23 @@ public void encode(final BsonWriter writer, final WriteRequestWithIndex writeReq
                 InsertRequest insertRequest = (InsertRequest) writeRequestWithIndex.getWriteRequest();
                 BsonDocument document = insertRequest.getDocument();
 
-                IdHoldingBsonWriter idHoldingBsonWriter = new IdHoldingBsonWriter(writer);
-                getCodec(document).encode(idHoldingBsonWriter, document,
-                        EncoderContext.builder().isEncodingCollectibleDocument(true).build());
-                insertedIds.put(writeRequestWithIndex.getIndex(), idHoldingBsonWriter.getId());
+                BsonValue documentId = insertedIds.compute(
+                        writeRequestWithIndex.getIndex(),
+                        (writeRequestIndex, writeRequestDocumentId) -> {
+                            IdHoldingBsonWriter idHoldingBsonWriter = new IdHoldingBsonWriter(
+                                    writer,
+                                    // Reuse `writeRequestDocumentId` if it may have been generated
+                                    // by `IdHoldingBsonWriter` in a previous attempt.
+                                    // If its type is not `BsonObjectId`, we know it could not have been generated.
+                                    writeRequestDocumentId instanceof BsonObjectId ? writeRequestDocumentId.asObjectId() : null);
+                            getCodec(document).encode(idHoldingBsonWriter, document,
+                                    EncoderContext.builder().isEncodingCollectibleDocument(true).build());
+                            return idHoldingBsonWriter.getId();
+                        });
+                if (documentId == null) {
+                    // we must add an entry anyway because we rely on all the indexes being present
+                    insertedIds.put(writeRequestWithIndex.getIndex(), null);
+                }
             } else if (writeRequestWithIndex.getType() == WriteRequest.Type.UPDATE
                     || writeRequestWithIndex.getType() == WriteRequest.Type.REPLACE) {
                 UpdateRequest update = (UpdateRequest) writeRequestWithIndex.getWriteRequest();
diff --git a/driver-core/src/test/unit/com/mongodb/internal/connection/IdHoldingBsonWriterSpecification.groovy b/driver-core/src/test/unit/com/mongodb/internal/connection/IdHoldingBsonWriterSpecification.groovy
index 451545632d4..f603576ecfb 100644
--- a/driver-core/src/test/unit/com/mongodb/internal/connection/IdHoldingBsonWriterSpecification.groovy
+++ b/driver-core/src/test/unit/com/mongodb/internal/connection/IdHoldingBsonWriterSpecification.groovy
@@ -32,11 +32,12 @@ import static org.bson.BsonHelper.documentWithValuesOfEveryType
 import static org.bson.BsonHelper.getBsonValues
 
 class IdHoldingBsonWriterSpecification extends Specification {
+    private static final OBJECT_ID = new BsonObjectId()
 
     def 'should write all types'() {
         given:
         def bsonBinaryWriter = new BsonBinaryWriter(new BasicOutputBuffer())
-        def idTrackingBsonWriter = new IdHoldingBsonWriter(bsonBinaryWriter)
+        def idTrackingBsonWriter = new IdHoldingBsonWriter(bsonBinaryWriter, fallbackId)
         def document = documentWithValuesOfEveryType()
 
         when:
@@ -47,18 +48,25 @@ class IdHoldingBsonWriterSpecification extends Specification {
         !document.containsKey('_id')
         encodedDocument.containsKey('_id')
         idTrackingBsonWriter.getId() == encodedDocument.get('_id')
+        if (expectedIdNullIfMustBeGenerated != null) { // nested in `if`: Spock needs an explicit assert here
+            assert idTrackingBsonWriter.getId() == expectedIdNullIfMustBeGenerated
+        }
 
         when:
         encodedDocument.remove('_id')
 
         then:
         encodedDocument == document
+
+        where:
+        fallbackId << [null, OBJECT_ID]
+        expectedIdNullIfMustBeGenerated << [null, OBJECT_ID]
     }
 
     def 'should support all types for _id value'() {
         given:
         def bsonBinaryWriter = new BsonBinaryWriter(new BasicOutputBuffer())
-        def idTrackingBsonWriter = new IdHoldingBsonWriter(bsonBinaryWriter)
+        def idTrackingBsonWriter = new IdHoldingBsonWriter(bsonBinaryWriter, fallbackId)
         def document = new BsonDocument()
         document.put('_id', id)
 
@@ -71,12 +79,15 @@ class IdHoldingBsonWriterSpecification extends Specification {
         idTrackingBsonWriter.getId() == id
 
         where:
-        id << getBsonValues()
+        [id, fallbackId] << [
+                getBsonValues(),
+                [null, new BsonObjectId()]
+        ].combinations()
     }
 
     def 'serialize document with list of documents that contain an _id field'() {
         def bsonBinaryWriter = new BsonBinaryWriter(new BasicOutputBuffer())
-        def idTrackingBsonWriter = new IdHoldingBsonWriter(bsonBinaryWriter)
+        def idTrackingBsonWriter = new IdHoldingBsonWriter(bsonBinaryWriter, fallbackId)
         def document = new BsonDocument('_id', new BsonObjectId())
                 .append('items', new BsonArray(Collections.singletonList(new BsonDocument('_id', new BsonObjectId()))))
 
@@ -86,11 +97,14 @@ class IdHoldingBsonWriterSpecification extends Specification {
 
         then:
         encodedDocument == document
+
+        where:
+        fallbackId << [null, new BsonObjectId()]
     }
 
     def 'serialize _id documents containing arrays'() {
         def bsonBinaryWriter = new BsonBinaryWriter(new BasicOutputBuffer())
-        def idTrackingBsonWriter = new IdHoldingBsonWriter(bsonBinaryWriter)
+        def idTrackingBsonWriter = new IdHoldingBsonWriter(bsonBinaryWriter, fallbackId)
         BsonDocument document = BsonDocument.parse(json)
 
         when:
@@ -102,7 +116,8 @@ class IdHoldingBsonWriterSpecification extends Specification {
         encodedDocument == document
 
         where:
-        json << ['{"_id": {"a": []}, "b": 123}',
+        [json, fallbackId] << [
+                ['{"_id": {"a": []}, "b": 123}',
                  '{"_id": {"a": [1, 2]}, "b": 123}',
                  '{"_id": {"a": [[[[1]]]]}, "b": 123}',
                  '{"_id": {"a": [{"a": [1, 2]}]}, "b": 123}',
@@ -112,7 +127,9 @@ class IdHoldingBsonWriterSpecification extends Specification {
                  '{"_id": [1, 2], "b": 123}',
                  '{"_id": [[1], [[2]]], "b": 123}',
                  '{"_id": [{"a": 1}], "b": 123}',
-                 '{"_id": [{"a": [{"b": 123}]}]}']
+                 '{"_id": [{"a": [{"b": 123}]}]}'],
+                [null, new BsonObjectId()]
+        ].combinations()
     }
 
     private static BsonDocument getEncodedDocument(BsonOutput buffer) {

From 8bdaec2fcb5149d6c4dc4a0a12ced9de2537ac8c Mon Sep 17 00:00:00 2001
From: Valentin Kovalenko <valentin.kovalenko@mongodb.com>
Date: Thu, 15 Aug 2024 14:15:23 -0600
Subject: [PATCH 2/4] Add a test

---
 .../com/mongodb/client/CrudProseTest.java     | 65 +++++++++++++++++++
 1 file changed, 65 insertions(+)

diff --git a/driver-sync/src/test/functional/com/mongodb/client/CrudProseTest.java b/driver-sync/src/test/functional/com/mongodb/client/CrudProseTest.java
index b8d94cfe067..e62b387bcb7 100644
--- a/driver-sync/src/test/functional/com/mongodb/client/CrudProseTest.java
+++ b/driver-sync/src/test/functional/com/mongodb/client/CrudProseTest.java
@@ -19,21 +19,36 @@
 import com.mongodb.MongoBulkWriteException;
 import com.mongodb.MongoWriteConcernException;
 import com.mongodb.MongoWriteException;
+import com.mongodb.ServerAddress;
 import com.mongodb.client.model.CreateCollectionOptions;
 import com.mongodb.client.model.Filters;
 import com.mongodb.client.model.ValidationOptions;
+import com.mongodb.client.result.InsertOneResult;
+import com.mongodb.event.CommandListener;
+import com.mongodb.event.CommandStartedEvent;
 import org.bson.BsonArray;
 import org.bson.BsonDocument;
 import org.bson.BsonInt32;
 import org.bson.BsonString;
+import org.bson.BsonValue;
 import org.bson.Document;
+import org.bson.codecs.pojo.PojoCodecProvider;
 import org.junit.jupiter.api.BeforeEach;
 import org.junit.jupiter.api.Test;
 
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.TimeUnit;
+
 import static com.mongodb.ClusterFixture.isDiscoverableReplicaSet;
 import static com.mongodb.ClusterFixture.serverVersionAtLeast;
+import static com.mongodb.MongoClientSettings.getDefaultCodecRegistry;
+import static com.mongodb.client.Fixture.getMongoClientSettingsBuilder;
 import static java.lang.String.format;
 import static java.util.Arrays.asList;
+import static java.util.Collections.singletonList;
+import static org.bson.codecs.configuration.CodecRegistries.fromProviders;
+import static org.bson.codecs.configuration.CodecRegistries.fromRegistries;
 import static org.junit.jupiter.api.Assertions.assertEquals;
 import static org.junit.jupiter.api.Assertions.assertFalse;
 import static org.junit.jupiter.api.Assertions.assertNotNull;
@@ -114,6 +129,45 @@ public void testWriteErrorDetailsIsPropagated() {
         }
     }
 
+    /**
+     * This test is not from the specification.
+     */
+    @Test
+    @SuppressWarnings("try")
+    void insertMustGenerateIdAtMostOnce() throws ExecutionException, InterruptedException {
+        ServerAddress primaryServerAddress = Fixture.getPrimary();
+        CompletableFuture<BsonValue> futureIdGeneratedByFirstInsertAttempt = new CompletableFuture<>();
+        CommandListener commandListener = new CommandListener() {
+            @Override
+            public void commandStarted(final CommandStartedEvent event) {
+                if (event.getCommandName().equals("insert")) {
+                    futureIdGeneratedByFirstInsertAttempt.complete(event.getCommand().getArray("documents").get(0).asDocument().get("_id"));
+                }
+            }
+        };
+        BsonDocument failPointDocument = new BsonDocument("configureFailPoint", new BsonString("failCommand"))
+                .append("mode", new BsonDocument("times", new BsonInt32(1)))
+                .append("data", new BsonDocument()
+                        .append("failCommands", new BsonArray(singletonList(new BsonString("insert"))))
+                        .append("errorLabels", new BsonArray(singletonList(new BsonString("RetryableWriteError"))))
+                        .append("writeConcernError", new BsonDocument("code", new BsonInt32(91))
+                                .append("errmsg", new BsonString("Replication is being shut down"))));
+        try (MongoClient client = MongoClients.create(getMongoClientSettingsBuilder()
+                .retryWrites(true)
+                .addCommandListener(commandListener)
+                .applyToServerSettings(builder -> builder.heartbeatFrequency(50, TimeUnit.MILLISECONDS))
+                .build());
+             FailPoint ignored = FailPoint.enable(failPointDocument, primaryServerAddress)) {
+            MongoCollection<MyDocument> coll = client.getDatabase(database.getName())
+                    .getCollection(collection.getNamespace().getCollectionName(), MyDocument.class)
+                    .withCodecRegistry(fromRegistries(
+                            getDefaultCodecRegistry(),
+                            fromProviders(PojoCodecProvider.builder().automatic(true).build())));
+            InsertOneResult result = coll.insertOne(new MyDocument());
+            assertEquals(futureIdGeneratedByFirstInsertAttempt.get(), result.getInsertedId());
+        }
+    }
+
     private void setFailPoint() {
         failPointDocument = new BsonDocument("configureFailPoint", new BsonString("failCommand"))
                 .append("mode", new BsonDocument("times", new BsonInt32(1)))
@@ -130,4 +184,15 @@ private void setFailPoint() {
     private void disableFailPoint() {
         getCollectionHelper().runAdminCommand(failPointDocument.append("mode", new BsonString("off")));
     }
+
+    public static final class MyDocument {
+        private int v;
+
+        public MyDocument() {
+        }
+
+        public int getV() {
+            return v;
+        }
+    }
 }

From f0b26cce3a5189896627dfb4f66bc19bab5f2ee8 Mon Sep 17 00:00:00 2001
From: Valentin Kovalenko <valentin.kovalenko@mongodb.com>
Date: Thu, 15 Aug 2024 14:23:24 -0600
Subject: [PATCH 3/4] Improve the test

---
 .../com/mongodb/client/CrudProseTest.java         | 15 +++++++++++----
 1 file changed, 11 insertions(+), 4 deletions(-)

diff --git a/driver-sync/src/test/functional/com/mongodb/client/CrudProseTest.java b/driver-sync/src/test/functional/com/mongodb/client/CrudProseTest.java
index e62b387bcb7..b2c29d8dfc3 100644
--- a/driver-sync/src/test/functional/com/mongodb/client/CrudProseTest.java
+++ b/driver-sync/src/test/functional/com/mongodb/client/CrudProseTest.java
@@ -23,7 +23,6 @@
 import com.mongodb.client.model.CreateCollectionOptions;
 import com.mongodb.client.model.Filters;
 import com.mongodb.client.model.ValidationOptions;
-import com.mongodb.client.result.InsertOneResult;
 import com.mongodb.event.CommandListener;
 import com.mongodb.event.CommandStartedEvent;
 import org.bson.BsonArray;
@@ -137,11 +136,17 @@ public void testWriteErrorDetailsIsPropagated() {
     void insertMustGenerateIdAtMostOnce() throws ExecutionException, InterruptedException {
         ServerAddress primaryServerAddress = Fixture.getPrimary();
         CompletableFuture<BsonValue> futureIdGeneratedByFirstInsertAttempt = new CompletableFuture<>();
+        CompletableFuture<BsonValue> futureIdGeneratedBySecondInsertAttempt = new CompletableFuture<>();
         CommandListener commandListener = new CommandListener() {
             @Override
             public void commandStarted(final CommandStartedEvent event) {
                 if (event.getCommandName().equals("insert")) {
-                    futureIdGeneratedByFirstInsertAttempt.complete(event.getCommand().getArray("documents").get(0).asDocument().get("_id"));
+                    BsonValue generatedId = event.getCommand().getArray("documents").get(0).asDocument().get("_id");
+                    if (!futureIdGeneratedByFirstInsertAttempt.isDone()) {
+                        futureIdGeneratedByFirstInsertAttempt.complete(generatedId);
+                    } else {
+                        futureIdGeneratedBySecondInsertAttempt.complete(generatedId);
+                    }
                 }
             }
         };
@@ -163,8 +168,10 @@ public void commandStarted(final CommandStartedEvent event) {
                     .withCodecRegistry(fromRegistries(
                             getDefaultCodecRegistry(),
                             fromProviders(PojoCodecProvider.builder().automatic(true).build())));
-            InsertOneResult result = coll.insertOne(new MyDocument());
-            assertEquals(futureIdGeneratedByFirstInsertAttempt.get(), result.getInsertedId());
+            BsonValue insertedId = coll.insertOne(new MyDocument()).getInsertedId();
+            BsonValue idGeneratedByFirstInsertAttempt = futureIdGeneratedByFirstInsertAttempt.get();
+            assertEquals(idGeneratedByFirstInsertAttempt, insertedId);
+            assertEquals(idGeneratedByFirstInsertAttempt, futureIdGeneratedBySecondInsertAttempt.get());
         }
     }
 

From 23eb7d9729d6389027cdb7fddc0887b3dda8fa9b Mon Sep 17 00:00:00 2001
From: Valentin Kovalenko <valentin.kovalenko@mongodb.com>
Date: Thu, 15 Aug 2024 15:37:57 -0600
Subject: [PATCH 4/4] Add the `isDiscoverableReplicaSet` assumption to the test

---
 .../src/test/functional/com/mongodb/client/CrudProseTest.java    | 1 +
 1 file changed, 1 insertion(+)

diff --git a/driver-sync/src/test/functional/com/mongodb/client/CrudProseTest.java b/driver-sync/src/test/functional/com/mongodb/client/CrudProseTest.java
index b2c29d8dfc3..5d3907bb210 100644
--- a/driver-sync/src/test/functional/com/mongodb/client/CrudProseTest.java
+++ b/driver-sync/src/test/functional/com/mongodb/client/CrudProseTest.java
@@ -134,6 +134,7 @@ public void testWriteErrorDetailsIsPropagated() {
     @Test
     @SuppressWarnings("try")
     void insertMustGenerateIdAtMostOnce() throws ExecutionException, InterruptedException {
+        assumeTrue(isDiscoverableReplicaSet());
         ServerAddress primaryServerAddress = Fixture.getPrimary();
         CompletableFuture<BsonValue> futureIdGeneratedByFirstInsertAttempt = new CompletableFuture<>();
         CompletableFuture<BsonValue> futureIdGeneratedBySecondInsertAttempt = new CompletableFuture<>();