Skip to content

Commit 92e6150

Browse files
authored
PYTHON-3493 Bulk Write InsertOne Should Be Parameter Of Collection Type (#1106)
1 parent 133c55d commit 92e6150

17 files changed

+144
-38
lines changed

doc/examples/type_hints.rst

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,26 @@ These methods automatically add an "_id" field.
113113
>>> assert result is not None
114114
>>> assert result["year"] == 1993
115115
>>> # This will raise a type-checking error, despite being present, because it is added by PyMongo.
116+
>>> assert result["_id"] # type:ignore[typeddict-item]
117+
118+
This same typing scheme works for all of the insert methods (:meth:`~pymongo.collection.Collection.insert_one`,
119+
:meth:`~pymongo.collection.Collection.insert_many`, and :meth:`~pymongo.collection.Collection.bulk_write`).
120+
For `bulk_write` both :class:`~pymongo.operations.InsertOne` and :class:`~pymongo.operations.ReplaceOne` operators are generic.
121+
122+
.. doctest::
123+
:pyversion: >= 3.8
124+
125+
>>> from typing import TypedDict
126+
>>> from pymongo import MongoClient
127+
>>> from pymongo.operations import InsertOne
128+
>>> from pymongo.collection import Collection
129+
>>> client: MongoClient = MongoClient()
130+
>>> collection: Collection[Movie] = client.test.test
131+
>>> inserted = collection.bulk_write([InsertOne(Movie(name="Jurassic Park", year=1993))])
132+
>>> result = collection.find_one({"name": "Jurassic Park"})
133+
>>> assert result is not None
134+
>>> assert result["year"] == 1993
135+
>>> # This will raise a type-checking error, despite being present, because it is added by PyMongo.
116136
>>> assert result["_id"] # type:ignore[typeddict-item]
117137

118138
Modeling Document Types with TypedDict

mypy.ini

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ ignore_missing_imports = True
3333
ignore_missing_imports = True
3434

3535
[mypy-test.test_mypy]
36-
warn_unused_ignores = false
36+
warn_unused_ignores = True
3737

3838
[mypy-winkerberos.*]
3939
ignore_missing_imports = True

pymongo/collection.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,14 @@
7777
_FIND_AND_MODIFY_DOC_FIELDS = {"value": 1}
7878

7979

80-
_WriteOp = Union[InsertOne, DeleteOne, DeleteMany, ReplaceOne, UpdateOne, UpdateMany]
80+
_WriteOp = Union[
81+
InsertOne[_DocumentType],
82+
DeleteOne,
83+
DeleteMany,
84+
ReplaceOne[_DocumentType],
85+
UpdateOne,
86+
UpdateMany,
87+
]
8188
# Hint supports index name, "myIndex", or list of index pairs: [('x', 1), ('y', -1)]
8289
_IndexList = Sequence[Tuple[str, Union[int, str, Mapping[str, Any]]]]
8390
_IndexKeyHint = Union[str, _IndexList]
@@ -436,7 +443,7 @@ def with_options(
436443
@_csot.apply
437444
def bulk_write(
438445
self,
439-
requests: Sequence[_WriteOp],
446+
requests: Sequence[_WriteOp[_DocumentType]],
440447
ordered: bool = True,
441448
bypass_document_validation: bool = False,
442449
session: Optional["ClientSession"] = None,

pymongo/encryption.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
import enum
1919
import socket
2020
import weakref
21-
from typing import Any, Mapping, Optional, Sequence
21+
from typing import Any, Generic, Mapping, Optional, Sequence
2222

2323
try:
2424
from pymongocrypt.auto_encrypter import AutoEncrypter
@@ -55,6 +55,7 @@
5555
from pymongo.read_concern import ReadConcern
5656
from pymongo.results import BulkWriteResult, DeleteResult
5757
from pymongo.ssl_support import get_ssl_context
58+
from pymongo.typings import _DocumentType
5859
from pymongo.uri_parser import parse_host
5960
from pymongo.write_concern import WriteConcern
6061

@@ -430,7 +431,7 @@ class QueryType(str, enum.Enum):
430431
"""Used to encrypt a value for an equality query."""
431432

432433

433-
class ClientEncryption(object):
434+
class ClientEncryption(Generic[_DocumentType]):
434435
"""Explicit client-side field level encryption."""
435436

436437
def __init__(

pymongo/operations.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,21 +13,22 @@
1313
# limitations under the License.
1414

1515
"""Operation class definitions."""
16-
from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple, Union
16+
from typing import Any, Dict, Generic, List, Mapping, Optional, Sequence, Tuple, Union
1717

18+
from bson.raw_bson import RawBSONDocument
1819
from pymongo import helpers
1920
from pymongo.collation import validate_collation_or_none
2021
from pymongo.common import validate_boolean, validate_is_mapping, validate_list
2122
from pymongo.helpers import _gen_index_name, _index_document, _index_list
22-
from pymongo.typings import _CollationIn, _DocumentIn, _Pipeline
23+
from pymongo.typings import _CollationIn, _DocumentType, _Pipeline
2324

2425

25-
class InsertOne(object):
26+
class InsertOne(Generic[_DocumentType]):
2627
"""Represents an insert_one operation."""
2728

2829
__slots__ = ("_doc",)
2930

30-
def __init__(self, document: _DocumentIn) -> None:
31+
def __init__(self, document: Union[_DocumentType, RawBSONDocument]) -> None:
3132
"""Create an InsertOne instance.
3233
3334
For use with :meth:`~pymongo.collection.Collection.bulk_write`.
@@ -170,15 +171,15 @@ def __ne__(self, other: Any) -> bool:
170171
return not self == other
171172

172173

173-
class ReplaceOne(object):
174+
class ReplaceOne(Generic[_DocumentType]):
174175
"""Represents a replace_one operation."""
175176

176177
__slots__ = ("_filter", "_doc", "_upsert", "_collation", "_hint")
177178

178179
def __init__(
179180
self,
180181
filter: Mapping[str, Any],
181-
replacement: Mapping[str, Any],
182+
replacement: Union[_DocumentType, RawBSONDocument],
182183
upsert: bool = False,
183184
collation: Optional[_CollationIn] = None,
184185
hint: Optional[_IndexKeyHint] = None,

pymongo/typings.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,3 +37,10 @@
3737
_Pipeline = Sequence[Mapping[str, Any]]
3838
_DocumentOut = _DocumentIn
3939
_DocumentType = TypeVar("_DocumentType", bound=Mapping[str, Any])
40+
41+
42+
def strip_optional(elem):
43+
"""This function is to allow us to cast all of the elements of an iterator from Optional[_T] to _T
44+
while inside a list comprehension."""
45+
assert elem is not None
46+
return elem

test/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1090,7 +1090,7 @@ def print_thread_stacks(pid: int) -> None:
10901090
class IntegrationTest(PyMongoTestCase):
10911091
"""Base class for TestCases that need a connection to MongoDB to pass."""
10921092

1093-
client: MongoClient
1093+
client: MongoClient[dict]
10941094
db: Database
10951095
credentials: Dict[str, str]
10961096

test/mockupdb/test_cluster_time.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def callback(client):
6060
self.cluster_time_conversation(callback, [{"ok": 1}] * 2)
6161

6262
def test_bulk(self):
63-
def callback(client):
63+
def callback(client: MongoClient[dict]) -> None:
6464
client.db.collection.bulk_write(
6565
[InsertOne({}), InsertOne({}), UpdateOne({}, {"$inc": {"x": 1}}), DeleteMany({})]
6666
)

test/mockupdb/test_op_msg.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -137,22 +137,22 @@
137137
# Legacy methods
138138
Operation(
139139
"bulk_write_insert",
140-
lambda coll: coll.bulk_write([InsertOne({}), InsertOne({})]),
140+
lambda coll: coll.bulk_write([InsertOne[dict]({}), InsertOne[dict]({})]),
141141
request=OpMsg({"insert": "coll"}, flags=0),
142142
reply={"ok": 1, "n": 2},
143143
),
144144
Operation(
145145
"bulk_write_insert-w0",
146146
lambda coll: coll.with_options(write_concern=WriteConcern(w=0)).bulk_write(
147-
[InsertOne({}), InsertOne({})]
147+
[InsertOne[dict]({}), InsertOne[dict]({})]
148148
),
149149
request=OpMsg({"insert": "coll"}, flags=0),
150150
reply={"ok": 1, "n": 2},
151151
),
152152
Operation(
153153
"bulk_write_insert-w0-unordered",
154154
lambda coll: coll.with_options(write_concern=WriteConcern(w=0)).bulk_write(
155-
[InsertOne({}), InsertOne({})], ordered=False
155+
[InsertOne[dict]({}), InsertOne[dict]({})], ordered=False
156156
),
157157
request=OpMsg({"insert": "coll"}, flags=OP_MSG_FLAGS["moreToCome"]),
158158
reply=None,

test/test_bulk.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -296,7 +296,7 @@ def test_upsert(self):
296296
def test_numerous_inserts(self):
297297
# Ensure we don't exceed server's maxWriteBatchSize size limit.
298298
n_docs = client_context.max_write_batch_size + 100
299-
requests = [InsertOne({}) for _ in range(n_docs)]
299+
requests = [InsertOne[dict]({}) for _ in range(n_docs)]
300300
result = self.coll.bulk_write(requests, ordered=False)
301301
self.assertEqual(n_docs, result.inserted_count)
302302
self.assertEqual(n_docs, self.coll.count_documents({}))
@@ -347,7 +347,7 @@ def test_bulk_write_no_results(self):
347347

348348
def test_bulk_write_invalid_arguments(self):
349349
# The requests argument must be a list.
350-
generator = (InsertOne({}) for _ in range(10))
350+
generator = (InsertOne[dict]({}) for _ in range(10))
351351
with self.assertRaises(TypeError):
352352
self.coll.bulk_write(generator) # type: ignore[arg-type]
353353

test/test_client.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1652,6 +1652,7 @@ def test_network_error_message(self):
16521652
with self.fail_point(
16531653
{"mode": {"times": 1}, "data": {"closeConnection": True, "failCommands": ["find"]}}
16541654
):
1655+
assert client.address is not None
16551656
expected = "%s:%s: " % client.address
16561657
with self.assertRaisesRegex(AutoReconnect, expected):
16571658
client.pymongo_test.test.find_one({})

test/test_database.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
import re
1818
import sys
19-
from typing import Any, Iterable, List, Mapping
19+
from typing import Any, Iterable, List, Mapping, Union
2020

2121
sys.path[0:0] = [""]
2222

@@ -201,7 +201,7 @@ def test_list_collection_names_filter(self):
201201
db.capped.insert_one({})
202202
db.non_capped.insert_one({})
203203
self.addCleanup(client.drop_database, db.name)
204-
204+
filter: Union[None, dict]
205205
# Should not send nameOnly.
206206
for filter in ({"options.capped": True}, {"options.capped": True, "name": "capped"}):
207207
results.clear()
@@ -210,7 +210,6 @@ def test_list_collection_names_filter(self):
210210
self.assertNotIn("nameOnly", results["started"][0].command)
211211

212212
# Should send nameOnly (except on 2.6).
213-
filter: Any
214213
for filter in (None, {}, {"name": {"$in": ["capped", "non_capped"]}}):
215214
results.clear()
216215
names = db.list_collection_names(filter=filter)

test/test_mypy.py

Lines changed: 69 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import os
1818
import tempfile
1919
import unittest
20-
from typing import TYPE_CHECKING, Any, Dict, Iterable, Iterator, List
20+
from typing import TYPE_CHECKING, Any, Dict, Iterable, Iterator, List, Union
2121

2222
try:
2323
from typing_extensions import NotRequired, TypedDict
@@ -42,7 +42,7 @@ class ImplicitMovie(TypedDict):
4242
Movie = dict # type:ignore[misc,assignment]
4343
ImplicitMovie = dict # type: ignore[assignment,misc]
4444
MovieWithId = dict # type: ignore[assignment,misc]
45-
TypedDict = None # type: ignore[assignment]
45+
TypedDict = None
4646
NotRequired = None # type: ignore[assignment]
4747

4848

@@ -59,7 +59,7 @@ class ImplicitMovie(TypedDict):
5959
from bson.son import SON
6060
from pymongo import ASCENDING, MongoClient
6161
from pymongo.collection import Collection
62-
from pymongo.operations import InsertOne
62+
from pymongo.operations import DeleteOne, InsertOne, ReplaceOne
6363
from pymongo.read_preferences import ReadPreference
6464

6565
TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "mypy_fails")
@@ -124,11 +124,40 @@ def to_list(iterable: Iterable[Dict[str, Any]]) -> List[Dict[str, Any]]:
124124
docs = to_list(cursor)
125125
self.assertTrue(docs)
126126

127+
@only_type_check
127128
def test_bulk_write(self) -> None:
128129
self.coll.insert_one({})
129-
requests = [InsertOne({})]
130-
result = self.coll.bulk_write(requests)
131-
self.assertTrue(result.acknowledged)
130+
coll: Collection[Movie] = self.coll
131+
requests: List[InsertOne[Movie]] = [InsertOne(Movie(name="American Graffiti", year=1973))]
132+
self.assertTrue(coll.bulk_write(requests).acknowledged)
133+
new_requests: List[Union[InsertOne[Movie], ReplaceOne[Movie]]] = []
134+
input_list: List[Union[InsertOne[Movie], ReplaceOne[Movie]]] = [
135+
InsertOne(Movie(name="American Graffiti", year=1973)),
136+
ReplaceOne({}, Movie(name="American Graffiti", year=1973)),
137+
]
138+
for i in input_list:
139+
new_requests.append(i)
140+
self.assertTrue(coll.bulk_write(new_requests).acknowledged)
141+
142+
# Because ReplaceOne is not generic, type checking is not enforced for ReplaceOne in the first example.
143+
@only_type_check
144+
def test_bulk_write_heterogeneous(self):
145+
coll: Collection[Movie] = self.coll
146+
requests: List[Union[InsertOne[Movie], ReplaceOne, DeleteOne]] = [
147+
InsertOne(Movie(name="American Graffiti", year=1973)),
148+
ReplaceOne({}, {"name": "American Graffiti", "year": "WRONG_TYPE"}),
149+
DeleteOne({}),
150+
]
151+
self.assertTrue(coll.bulk_write(requests).acknowledged)
152+
requests_two: List[Union[InsertOne[Movie], ReplaceOne[Movie], DeleteOne]] = [
153+
InsertOne(Movie(name="American Graffiti", year=1973)),
154+
ReplaceOne(
155+
{},
156+
{"name": "American Graffiti", "year": "WRONG_TYPE"}, # type:ignore[typeddict-item]
157+
),
158+
DeleteOne({}),
159+
]
160+
self.assertTrue(coll.bulk_write(requests_two).acknowledged)
132161

133162
def test_command(self) -> None:
134163
result: Dict = self.client.admin.command("ping")
@@ -340,6 +369,40 @@ def test_typeddict_document_type_insertion(self) -> None:
340369
)
341370
coll.insert_many([bad_movie])
342371

372+
@only_type_check
373+
def test_bulk_write_document_type_insertion(self):
374+
client: MongoClient[MovieWithId] = MongoClient()
375+
coll: Collection[MovieWithId] = client.test.test
376+
coll.bulk_write(
377+
[InsertOne(Movie({"name": "THX-1138", "year": 1971}))] # type:ignore[arg-type]
378+
)
379+
mov_dict = {"_id": ObjectId(), "name": "THX-1138", "year": 1971}
380+
coll.bulk_write(
381+
[InsertOne(mov_dict)] # type:ignore[arg-type]
382+
)
383+
coll.bulk_write(
384+
[
385+
InsertOne({"_id": ObjectId(), "name": "THX-1138", "year": 1971})
386+
] # No error because it is in-line.
387+
)
388+
389+
@only_type_check
390+
def test_bulk_write_document_type_replacement(self):
391+
client: MongoClient[MovieWithId] = MongoClient()
392+
coll: Collection[MovieWithId] = client.test.test
393+
coll.bulk_write(
394+
[ReplaceOne({}, Movie({"name": "THX-1138", "year": 1971}))] # type:ignore[arg-type]
395+
)
396+
mov_dict = {"_id": ObjectId(), "name": "THX-1138", "year": 1971}
397+
coll.bulk_write(
398+
[ReplaceOne({}, mov_dict)] # type:ignore[arg-type]
399+
)
400+
coll.bulk_write(
401+
[
402+
ReplaceOne({}, {"_id": ObjectId(), "name": "THX-1138", "year": 1971})
403+
] # No error because it is in-line.
404+
)
405+
343406
@only_type_check
344407
def test_typeddict_explicit_document_type(self) -> None:
345408
out = MovieWithId(_id=ObjectId(), name="THX-1138", year=1971)

test/test_server_selection.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from pymongo.server_selectors import writable_server_selector
2424
from pymongo.settings import TopologySettings
2525
from pymongo.topology import Topology
26+
from pymongo.typings import strip_optional
2627

2728
sys.path[0:0] = [""]
2829

@@ -85,7 +86,10 @@ def all_hosts_started():
8586
)
8687

8788
wait_until(all_hosts_started, "receive heartbeat from all hosts")
88-
expected_port = max([n.address[1] for n in client._topology._description.readable_servers])
89+
90+
expected_port = max(
91+
[strip_optional(n.address[1]) for n in client._topology._description.readable_servers]
92+
)
8993

9094
# Insert 1 record and access it 10 times.
9195
coll.insert_one({"name": "John Doe"})

test/test_session.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -898,7 +898,9 @@ def _test_writes(self, op):
898898

899899
@client_context.require_no_standalone
900900
def test_writes(self):
901-
self._test_writes(lambda coll, session: coll.bulk_write([InsertOne({})], session=session))
901+
self._test_writes(
902+
lambda coll, session: coll.bulk_write([InsertOne[dict]({})], session=session)
903+
)
902904
self._test_writes(lambda coll, session: coll.insert_one({}, session=session))
903905
self._test_writes(lambda coll, session: coll.insert_many([{}], session=session))
904906
self._test_writes(
@@ -944,7 +946,7 @@ def _test_no_read_concern(self, op):
944946
@client_context.require_no_standalone
945947
def test_writes_do_not_include_read_concern(self):
946948
self._test_no_read_concern(
947-
lambda coll, session: coll.bulk_write([InsertOne({})], session=session)
949+
lambda coll, session: coll.bulk_write([InsertOne[dict]({})], session=session)
948950
)
949951
self._test_no_read_concern(lambda coll, session: coll.insert_one({}, session=session))
950952
self._test_no_read_concern(lambda coll, session: coll.insert_many([{}], session=session))

0 commit comments

Comments
 (0)