Skip to content

Commit 247d7b4

Browse files
authored
MixedBulkWriteOperation should generate inserted document IDs at most once per batch (#1482) (#1484)
This is a backport of #1482 to `4.11.x` JAVA-5572
1 parent 1f425c4 commit 247d7b4

File tree

4 files changed

+134
-13
lines changed

4 files changed

+134
-13
lines changed

driver-core/src/main/com/mongodb/internal/connection/IdHoldingBsonWriter.java

+18-2
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
package com.mongodb.internal.connection;
1818

19+
import com.mongodb.lang.Nullable;
1920
import org.bson.BsonBinary;
2021
import org.bson.BsonBinaryWriter;
2122
import org.bson.BsonBoolean;
@@ -57,11 +58,17 @@ public class IdHoldingBsonWriter extends LevelCountingBsonWriter {
5758
private LevelCountingBsonWriter idBsonBinaryWriter;
5859
private BasicOutputBuffer outputBuffer;
5960
private String currentFieldName;
61+
private final BsonValue fallbackId;
6062
private BsonValue id;
6163
private boolean idFieldIsAnArray = false;
6264

63-
public IdHoldingBsonWriter(final BsonWriter bsonWriter) {
65+
/**
66+
* @param fallbackId The "_id" field value to use if the top-level document written via this {@link BsonWriter}
67+
* does not have "_id". If {@code null}, then a new {@link BsonObjectId} is generated instead.
68+
*/
69+
public IdHoldingBsonWriter(final BsonWriter bsonWriter, @Nullable final BsonObjectId fallbackId) {
6470
super(bsonWriter);
71+
this.fallbackId = fallbackId;
6572
}
6673

6774
@Override
@@ -99,7 +106,7 @@ public void writeEndDocument() {
99106
}
100107

101108
if (getCurrentLevel() == 0 && id == null) {
102-
id = new BsonObjectId();
109+
id = fallbackId == null ? new BsonObjectId() : fallbackId;
103110
writeObjectId(ID_FIELD_NAME, id.asObjectId().getValue());
104111
}
105112
super.writeEndDocument();
@@ -408,6 +415,15 @@ public void flush() {
408415
super.flush();
409416
}
410417

418+
/**
419+
* Returns either the value of the "_id" field from the top-level document written via this {@link BsonWriter},
420+
* provided that the document is not {@link RawBsonDocument},
421+
* or the generated {@link BsonObjectId}.
422+
* If the document is {@link RawBsonDocument}, then returns {@code null}.
423+
* <p>
424+
* {@linkplain #flush() Flushing} is not required before calling this method.</p>
425+
*/
426+
@Nullable
411427
public BsonValue getId() {
412428
return id;
413429
}

driver-core/src/main/com/mongodb/internal/connection/SplittablePayload.java

+18-4
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import com.mongodb.internal.bulk.WriteRequestWithIndex;
2424
import org.bson.BsonDocument;
2525
import org.bson.BsonDocumentWrapper;
26+
import org.bson.BsonObjectId;
2627
import org.bson.BsonValue;
2728
import org.bson.BsonWriter;
2829
import org.bson.codecs.BsonValueCodecProvider;
@@ -191,10 +192,23 @@ public void encode(final BsonWriter writer, final WriteRequestWithIndex writeReq
191192
InsertRequest insertRequest = (InsertRequest) writeRequestWithIndex.getWriteRequest();
192193
BsonDocument document = insertRequest.getDocument();
193194

194-
IdHoldingBsonWriter idHoldingBsonWriter = new IdHoldingBsonWriter(writer);
195-
getCodec(document).encode(idHoldingBsonWriter, document,
196-
EncoderContext.builder().isEncodingCollectibleDocument(true).build());
197-
insertedIds.put(writeRequestWithIndex.getIndex(), idHoldingBsonWriter.getId());
195+
BsonValue documentId = insertedIds.compute(
196+
writeRequestWithIndex.getIndex(),
197+
(writeRequestIndex, writeRequestDocumentId) -> {
198+
IdHoldingBsonWriter idHoldingBsonWriter = new IdHoldingBsonWriter(
199+
writer,
200+
// Reuse `writeRequestDocumentId` if it may have been generated
201+
// by `IdHoldingBsonWriter` in a previous attempt.
202+
// If its type is not `BsonObjectId`, we know it could not have been generated.
203+
writeRequestDocumentId instanceof BsonObjectId ? writeRequestDocumentId.asObjectId() : null);
204+
getCodec(document).encode(idHoldingBsonWriter, document,
205+
EncoderContext.builder().isEncodingCollectibleDocument(true).build());
206+
return idHoldingBsonWriter.getId();
207+
});
208+
if (documentId == null) {
209+
// we must add an entry anyway because we rely on all the indexes being present
210+
insertedIds.put(writeRequestWithIndex.getIndex(), null);
211+
}
198212
} else if (writeRequestWithIndex.getType() == WriteRequest.Type.UPDATE
199213
|| writeRequestWithIndex.getType() == WriteRequest.Type.REPLACE) {
200214
UpdateRequest update = (UpdateRequest) writeRequestWithIndex.getWriteRequest();

driver-core/src/test/unit/com/mongodb/internal/connection/IdHoldingBsonWriterSpecification.groovy

+24-7
Original file line numberDiff line numberDiff line change
@@ -32,11 +32,12 @@ import static org.bson.BsonHelper.documentWithValuesOfEveryType
3232
import static org.bson.BsonHelper.getBsonValues
3333

3434
class IdHoldingBsonWriterSpecification extends Specification {
35+
private static final OBJECT_ID = new BsonObjectId()
3536

3637
def 'should write all types'() {
3738
given:
3839
def bsonBinaryWriter = new BsonBinaryWriter(new BasicOutputBuffer())
39-
def idTrackingBsonWriter = new IdHoldingBsonWriter(bsonBinaryWriter)
40+
def idTrackingBsonWriter = new IdHoldingBsonWriter(bsonBinaryWriter, fallbackId)
4041
def document = documentWithValuesOfEveryType()
4142

4243
when:
@@ -47,18 +48,25 @@ class IdHoldingBsonWriterSpecification extends Specification {
4748
!document.containsKey('_id')
4849
encodedDocument.containsKey('_id')
4950
idTrackingBsonWriter.getId() == encodedDocument.get('_id')
51+
if (expectedIdNullIfMustBeGenerated != null) {
52+
idTrackingBsonWriter.getId() == expectedIdNullIfMustBeGenerated
53+
}
5054

5155
when:
5256
encodedDocument.remove('_id')
5357

5458
then:
5559
encodedDocument == document
60+
61+
where:
62+
fallbackId << [null, OBJECT_ID]
63+
expectedIdNullIfMustBeGenerated << [null, OBJECT_ID]
5664
}
5765

5866
def 'should support all types for _id value'() {
5967
given:
6068
def bsonBinaryWriter = new BsonBinaryWriter(new BasicOutputBuffer())
61-
def idTrackingBsonWriter = new IdHoldingBsonWriter(bsonBinaryWriter)
69+
def idTrackingBsonWriter = new IdHoldingBsonWriter(bsonBinaryWriter, fallbackId)
6270
def document = new BsonDocument()
6371
document.put('_id', id)
6472

@@ -71,12 +79,15 @@ class IdHoldingBsonWriterSpecification extends Specification {
7179
idTrackingBsonWriter.getId() == id
7280

7381
where:
74-
id << getBsonValues()
82+
[id, fallbackId] << [
83+
getBsonValues(),
84+
[null, new BsonObjectId()]
85+
].combinations()
7586
}
7687

7788
def 'serialize document with list of documents that contain an _id field'() {
7889
def bsonBinaryWriter = new BsonBinaryWriter(new BasicOutputBuffer())
79-
def idTrackingBsonWriter = new IdHoldingBsonWriter(bsonBinaryWriter)
90+
def idTrackingBsonWriter = new IdHoldingBsonWriter(bsonBinaryWriter, fallbackId)
8091
def document = new BsonDocument('_id', new BsonObjectId())
8192
.append('items', new BsonArray(Collections.singletonList(new BsonDocument('_id', new BsonObjectId()))))
8293

@@ -86,11 +97,14 @@ class IdHoldingBsonWriterSpecification extends Specification {
8697

8798
then:
8899
encodedDocument == document
100+
101+
where:
102+
fallbackId << [null, new BsonObjectId()]
89103
}
90104

91105
def 'serialize _id documents containing arrays'() {
92106
def bsonBinaryWriter = new BsonBinaryWriter(new BasicOutputBuffer())
93-
def idTrackingBsonWriter = new IdHoldingBsonWriter(bsonBinaryWriter)
107+
def idTrackingBsonWriter = new IdHoldingBsonWriter(bsonBinaryWriter, fallbackId)
94108
BsonDocument document = BsonDocument.parse(json)
95109

96110
when:
@@ -102,7 +116,8 @@ class IdHoldingBsonWriterSpecification extends Specification {
102116
encodedDocument == document
103117

104118
where:
105-
json << ['{"_id": {"a": []}, "b": 123}',
119+
[json, fallbackId] << [
120+
['{"_id": {"a": []}, "b": 123}',
106121
'{"_id": {"a": [1, 2]}, "b": 123}',
107122
'{"_id": {"a": [[[[1]]]]}, "b": 123}',
108123
'{"_id": {"a": [{"a": [1, 2]}]}, "b": 123}',
@@ -112,7 +127,9 @@ class IdHoldingBsonWriterSpecification extends Specification {
112127
'{"_id": [1, 2], "b": 123}',
113128
'{"_id": [[1], [[2]]], "b": 123}',
114129
'{"_id": [{"a": 1}], "b": 123}',
115-
'{"_id": [{"a": [{"b": 123}]}]}']
130+
'{"_id": [{"a": [{"b": 123}]}]}'],
131+
[null, new BsonObjectId()]
132+
].combinations()
116133
}
117134

118135
private static BsonDocument getEncodedDocument(BsonOutput buffer) {

driver-sync/src/test/functional/com/mongodb/client/CrudProseTest.java

+74
Original file line numberDiff line numberDiff line change
@@ -19,21 +19,35 @@
1919
import com.mongodb.MongoBulkWriteException;
2020
import com.mongodb.MongoWriteConcernException;
2121
import com.mongodb.MongoWriteException;
22+
import com.mongodb.ServerAddress;
2223
import com.mongodb.client.model.CreateCollectionOptions;
2324
import com.mongodb.client.model.Filters;
2425
import com.mongodb.client.model.ValidationOptions;
26+
import com.mongodb.event.CommandListener;
27+
import com.mongodb.event.CommandStartedEvent;
2528
import org.bson.BsonArray;
2629
import org.bson.BsonDocument;
2730
import org.bson.BsonInt32;
2831
import org.bson.BsonString;
32+
import org.bson.BsonValue;
2933
import org.bson.Document;
34+
import org.bson.codecs.pojo.PojoCodecProvider;
3035
import org.junit.Before;
3136
import org.junit.Test;
3237

38+
import java.util.concurrent.CompletableFuture;
39+
import java.util.concurrent.ExecutionException;
40+
import java.util.concurrent.TimeUnit;
41+
3342
import static com.mongodb.ClusterFixture.isDiscoverableReplicaSet;
3443
import static com.mongodb.ClusterFixture.serverVersionAtLeast;
44+
import static com.mongodb.MongoClientSettings.getDefaultCodecRegistry;
45+
import static com.mongodb.client.Fixture.getMongoClientSettingsBuilder;
3546
import static java.lang.String.format;
3647
import static java.util.Arrays.asList;
48+
import static java.util.Collections.singletonList;
49+
import static org.bson.codecs.configuration.CodecRegistries.fromProviders;
50+
import static org.bson.codecs.configuration.CodecRegistries.fromRegistries;
3751
import static org.junit.Assert.assertEquals;
3852
import static org.junit.Assert.assertFalse;
3953
import static org.junit.Assert.assertNotNull;
@@ -116,6 +130,55 @@ public void testWriteErrorDetailsIsPropagated() {
116130
}
117131
}
118132

133+
/**
134+
* This test is not from the specification.
135+
*/
136+
@Test
137+
@SuppressWarnings("try")
138+
public void insertMustGenerateIdAtMostOnce() throws ExecutionException, InterruptedException {
139+
assumeTrue(serverVersionAtLeast(4, 0));
140+
assumeTrue(isDiscoverableReplicaSet());
141+
ServerAddress primaryServerAddress = Fixture.getPrimary();
142+
CompletableFuture<BsonValue> futureIdGeneratedByFirstInsertAttempt = new CompletableFuture<>();
143+
CompletableFuture<BsonValue> futureIdGeneratedBySecondInsertAttempt = new CompletableFuture<>();
144+
CommandListener commandListener = new CommandListener() {
145+
@Override
146+
public void commandStarted(final CommandStartedEvent event) {
147+
if (event.getCommandName().equals("insert")) {
148+
BsonValue generatedId = event.getCommand().getArray("documents").get(0).asDocument().get("_id");
149+
if (!futureIdGeneratedByFirstInsertAttempt.isDone()) {
150+
futureIdGeneratedByFirstInsertAttempt.complete(generatedId);
151+
} else {
152+
futureIdGeneratedBySecondInsertAttempt.complete(generatedId);
153+
}
154+
}
155+
}
156+
};
157+
BsonDocument failPointDocument = new BsonDocument("configureFailPoint", new BsonString("failCommand"))
158+
.append("mode", new BsonDocument("times", new BsonInt32(1)))
159+
.append("data", new BsonDocument()
160+
.append("failCommands", new BsonArray(singletonList(new BsonString("insert"))))
161+
.append("errorLabels", new BsonArray(singletonList(new BsonString("RetryableWriteError"))))
162+
.append("writeConcernError", new BsonDocument("code", new BsonInt32(91))
163+
.append("errmsg", new BsonString("Replication is being shut down"))));
164+
try (MongoClient client = MongoClients.create(getMongoClientSettingsBuilder()
165+
.retryWrites(true)
166+
.addCommandListener(commandListener)
167+
.applyToServerSettings(builder -> builder.heartbeatFrequency(50, TimeUnit.MILLISECONDS))
168+
.build());
169+
FailPoint ignored = FailPoint.enable(failPointDocument, primaryServerAddress)) {
170+
MongoCollection<MyDocument> coll = client.getDatabase(database.getName())
171+
.getCollection(collection.getNamespace().getCollectionName(), MyDocument.class)
172+
.withCodecRegistry(fromRegistries(
173+
getDefaultCodecRegistry(),
174+
fromProviders(PojoCodecProvider.builder().automatic(true).build())));
175+
BsonValue insertedId = coll.insertOne(new MyDocument()).getInsertedId();
176+
BsonValue idGeneratedByFirstInsertAttempt = futureIdGeneratedByFirstInsertAttempt.get();
177+
assertEquals(idGeneratedByFirstInsertAttempt, insertedId);
178+
assertEquals(idGeneratedByFirstInsertAttempt, futureIdGeneratedBySecondInsertAttempt.get());
179+
}
180+
}
181+
119182
private void setFailPoint() {
120183
failPointDocument = new BsonDocument("configureFailPoint", new BsonString("failCommand"))
121184
.append("mode", new BsonDocument("times", new BsonInt32(1)))
@@ -132,4 +195,15 @@ private void setFailPoint() {
132195
private void disableFailPoint() {
133196
getCollectionHelper().runAdminCommand(failPointDocument.append("mode", new BsonString("off")));
134197
}
198+
199+
public static final class MyDocument {
200+
private int v;
201+
202+
public MyDocument() {
203+
}
204+
205+
public int getV() {
206+
return v;
207+
}
208+
}
135209
}

0 commit comments

Comments
 (0)